Mirror of https://github.com/huggingface/diffusers.git
Allow converting Flax to PyTorch by adding a "from_flax" keyword (#1900)
* from_flax
* oops
* oops
* make style with pip install -e ".[dev]"
* oops
* now code quality happy 😋
* allow_patterns += FLAX_WEIGHTS_NAME
* Update src/diffusers/pipelines/pipeline_utils.py (×4)
* Update src/diffusers/models/modeling_utils.py
* Update src/diffusers/pipelines/pipeline_utils.py
* for test
* bye bye is_flax_available()
* oops
* Update src/diffusers/models/modeling_pytorch_flax_utils.py (×3)
* Update src/diffusers/models/modeling_utils.py (×2)
* make style
* add test
* finihs

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
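Sketch of the intended usage of the new keyword (the repo id below is hypothetical; any checkpoint that only ships Flax `.msgpack` weights qualifies):

import torch
from diffusers import StableDiffusionPipeline

# With this commit, a repository that only ships Flax (.msgpack) weights can be
# loaded straight into the PyTorch pipeline via the new `from_flax` keyword.
pipe = StableDiffusionPipeline.from_pretrained(
    "some-user/flax-only-stable-diffusion",  # hypothetical repo id
    from_flax=True,
)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")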
src/diffusers/models/modeling_pytorch_flax_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""

from pickle import UnpicklingError

import numpy as np

import jax
import jax.numpy as jnp
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict

from ..utils import logging


logger = logging.get_logger(__name__)


#####################
# Flax => PyTorch #
#####################


# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    try:
        with open(model_file, "rb") as flax_state_f:
            flax_state = from_bytes(None, flax_state_f.read())
    except UnpicklingError as e:
        try:
            with open(model_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                else:
                    raise ValueError from e
        except (UnicodeDecodeError, ValueError):
            raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object.")

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load Flax checkpoints in a PyTorch model."""

    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    pt_model.base_model_prefix = ""

    flax_state_dict = flatten_dict(flax_state, sep=".")
    pt_model_dict = pt_model.state_dict()

    # keep track of unexpected & missing keys
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    for flax_key_tuple, flax_tensor in flax_state_dict.items():
        flax_key_tuple_array = flax_key_tuple.split(".")

        # Flax conv kernels are stored as (H, W, in, out); PyTorch expects (out, in, H, W)
        if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        elif flax_key_tuple_array[-1] == "kernel":
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = flax_tensor.T
        elif flax_key_tuple_array[-1] == "scale":
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

        if "time_embedding" not in flax_key_tuple_array:
            for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
                flax_key_tuple_array[i] = (
                    flax_key_tuple_string.replace("_0", ".0")
                    .replace("_1", ".1")
                    .replace("_2", ".2")
                    .replace("_3", ".3")
                )

        flax_key = ".".join(flax_key_tuple_array)

        if flax_key in pt_model_dict:
            if flax_tensor.shape != pt_model_dict[flax_key].shape:
                raise ValueError(
                    f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
                    f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )
            else:
                # add weight to pytorch dict
                flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
                pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
                # remove from missing keys
                missing_keys.remove(flax_key)
        else:
            # weight is not expected by PyTorch model
            unexpected_keys.append(flax_key)

    pt_model.load_state_dict(pt_model_dict)

    # re-transform missing_keys to list
    missing_keys = list(missing_keys)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model
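To make the kernel-renaming rules above concrete, here is a small self-contained sketch (the toy module is an illustration, not part of the commit) of the layout difference that the (3, 2, 0, 1) transpose bridges:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict

# Toy Flax conv: Flax stores the kernel as (H, W, in_ch, out_ch), while
# torch.nn.Conv2d stores weight as (out_ch, in_ch, H, W) -- hence the
# transpose in load_flax_weights_in_pytorch_model above.
conv = nn.Conv(features=8, kernel_size=(3, 3))
variables = conv.init(jax.random.PRNGKey(0), jnp.ones((1, 16, 16, 4)))

flat = flatten_dict(unfreeze(variables), sep=".")
kernel = flat["params.kernel"]                # (3, 3, 4, 8)
weight = jnp.transpose(kernel, (3, 2, 0, 1))  # (8, 4, 3, 3), PyTorch layout
print(kernel.shape, "->", weight.shape)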
src/diffusers/models/modeling_utils.py

@@ -30,6 +30,7 @@ from .. import __version__
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
+   FLAX_WEIGHTS_NAME,
    HF_HUB_OFFLINE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    SAFETENSORS_WEIGHTS_NAME,
@@ -335,6 +336,8 @@ class ModelMixin(torch.nn.Module):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
+           from_flax (`bool`, *optional*, defaults to `False`):
+               Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
@@ -375,6 +378,7 @@ class ModelMixin(torch.nn.Module):
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
+       from_flax = kwargs.pop("from_flax", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
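The docstring and kwargs hunks above wire `from_flax` through `ModelMixin.from_pretrained`, so individual models can be converted too; a minimal sketch (path and subfolder are illustrative):

from diffusers import UNet2DConditionModel

# `from_flax=True` makes ModelMixin.from_pretrained look for the Flax
# .msgpack checkpoint instead of the PyTorch weights and convert it.
unet = UNet2DConditionModel.from_pretrained(
    "path/to/sd-checkpoint",  # illustrative path or repo id
    subfolder="unet",
    from_flax=True,
)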
@@ -433,27 +437,10 @@ class ModelMixin(torch.nn.Module):
        # Load model

        model_file = None
        if is_safetensors_available():
            try:
                model_file = cls._get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=SAFETENSORS_WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
            except:
                pass
-       if model_file is None:
+       if from_flax:
            model_file = cls._get_model_file(
                pretrained_model_name_or_path,
-               weights_name=WEIGHTS_NAME,
+               weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
@@ -464,49 +451,6 @@ class ModelMixin(torch.nn.Module):
                subfolder=subfolder,
                user_agent=user_agent,
            )

        if low_cpu_mem_usage:
            # Instantiate model with empty weights
            with accelerate.init_empty_weights():
                config, unused_kwargs = cls.load_config(
                    config_path,
                    cache_dir=cache_dir,
                    return_unused_kwargs=True,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    device_map=device_map,
                    **kwargs,
                )
                model = cls.from_config(config, **unused_kwargs)

            # if device_map is None, load the state dict and move the params from meta device to the cpu
            if device_map is None:
                param_device = "cpu"
                state_dict = load_state_dict(model_file)
                # move the params from meta device to cpu
                for param_name, param in state_dict.items():
                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
                    if accepts_dtype:
                        set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
                    else:
                        set_module_tensor_to_device(model, param_name, param_device, value=param)
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)

            loading_info = {
                "missing_keys": [],
                "unexpected_keys": [],
                "mismatched_keys": [],
                "error_msgs": [],
            }
        else:
            config, unused_kwargs = cls.load_config(
                config_path,
                cache_dir=cache_dir,
@@ -523,22 +467,121 @@ class ModelMixin(torch.nn.Module):
            )
            model = cls.from_config(config, **unused_kwargs)

-           state_dict = load_state_dict(model_file)
+           # Convert the weights
+           from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

-           model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-               model,
-               state_dict,
-               model_file,
-               pretrained_model_name_or_path,
-               ignore_mismatched_sizes=ignore_mismatched_sizes,
-           )
+           model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if is_safetensors_available():
                try:
                    model_file = cls._get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=SAFETENSORS_WEIGHTS_NAME,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                except:
                    pass
            if model_file is None:
                model_file = cls._get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )

-           loading_info = {
-               "missing_keys": missing_keys,
-               "unexpected_keys": unexpected_keys,
-               "mismatched_keys": mismatched_keys,
-               "error_msgs": error_msgs,
-           }
            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    config, unused_kwargs = cls.load_config(
                        config_path,
                        cache_dir=cache_dir,
                        return_unused_kwargs=True,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        device_map=device_map,
                        **kwargs,
                    )
                    model = cls.from_config(config, **unused_kwargs)

                # if device_map is None, load the state dict and move the params from meta device to the cpu
                if device_map is None:
                    param_device = "cpu"
                    state_dict = load_state_dict(model_file)
                    # move the params from meta device to cpu
                    for param_name, param in state_dict.items():
                        accepts_dtype = "dtype" in set(
                            inspect.signature(set_module_tensor_to_device).parameters.keys()
                        )
                        if accepts_dtype:
                            set_module_tensor_to_device(
                                model, param_name, param_device, value=param, dtype=torch_dtype
                            )
                        else:
                            set_module_tensor_to_device(model, param_name, param_device, value=param)
                else:  # else let accelerate handle loading and dispatching.
                    # Load weights and dispatch according to the device_map
                    # by default the device_map is None and the weights are loaded on the CPU
                    accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)

                loading_info = {
                    "missing_keys": [],
                    "unexpected_keys": [],
                    "mismatched_keys": [],
                    "error_msgs": [],
                }
            else:
                config, unused_kwargs = cls.load_config(
                    config_path,
                    cache_dir=cache_dir,
                    return_unused_kwargs=True,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    device_map=device_map,
                    **kwargs,
                )
                model = cls.from_config(config, **unused_kwargs)

                state_dict = load_state_dict(model_file)

                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                    model,
                    state_dict,
                    model_file,
                    pretrained_model_name_or_path,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                )

                loading_info = {
                    "missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "mismatched_keys": mismatched_keys,
                    "error_msgs": error_msgs,
                }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
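The `low_cpu_mem_usage` branch kept above relies on Accelerate's meta-device trick; a minimal standalone sketch of the same pattern (toy module, not the diffusers code path):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module skeleton on the "meta" device, so no weight memory is
# allocated up front ...
with init_empty_weights():
    model = torch.nn.Linear(1024, 1024)

# ... then materialize each parameter on CPU from a loaded state dict,
# which is what the branch above does tensor by tensor.
state_dict = {"weight": torch.zeros(1024, 1024), "bias": torch.zeros(1024)}
for name, value in state_dict.items():
    set_module_tensor_to_device(model, name, "cpu", value=value)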
src/diffusers/pipelines/pipeline_utils.py

@@ -37,6 +37,7 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
+   FLAX_WEIGHTS_NAME,
    HF_HUB_OFFLINE,
    ONNX_WEIGHTS_NAME,
    WEIGHTS_NAME,
@@ -445,6 +446,7 @@ class DiffusionPipeline(ConfigMixin):
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
+       from_flax = kwargs.pop("from_flax", False)
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
@@ -470,11 +472,26 @@ class DiffusionPipeline(ConfigMixin):
        # make sure we only download sub-folders and `diffusers` filenames
        folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
        allow_patterns = [os.path.join(k, "*") for k in folder_names]
-       allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+       allow_patterns += [
+           WEIGHTS_NAME,
+           SCHEDULER_CONFIG_NAME,
+           CONFIG_NAME,
+           ONNX_WEIGHTS_NAME,
+           cls.config_name,
+       ]

        # make sure we don't download flax weights
        ignore_patterns = ["*.msgpack"]

+       if from_flax:
+           ignore_patterns = ["*.bin", "*.safetensors"]
+           allow_patterns += [
+               SCHEDULER_CONFIG_NAME,
+               CONFIG_NAME,
+               FLAX_WEIGHTS_NAME,
+               cls.config_name,
+           ]

        if custom_pipeline is not None:
            allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
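These allow/ignore patterns feed `huggingface_hub.snapshot_download`; roughly, with `from_flax=True` the download filter flips as in this sketch (repo id hypothetical, file names taken from diffusers' constants as I understand them):

from huggingface_hub import snapshot_download

# With from_flax=True, PyTorch weights are ignored and the Flax checkpoint
# (diffusion_flax_model.msgpack) is allowed instead.
cached_folder = snapshot_download(
    "some-user/flax-only-stable-diffusion",  # hypothetical repo id
    allow_patterns=[
        "*/*",                           # sub-folder contents
        "scheduler_config.json",         # SCHEDULER_CONFIG_NAME
        "config.json",                   # CONFIG_NAME
        "diffusion_flax_model.msgpack",  # FLAX_WEIGHTS_NAME
        "model_index.json",              # cls.config_name for pipelines
    ],
    ignore_patterns=["*.bin", "*.safetensors"],
)
print(cached_folder)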
@@ -704,7 +721,14 @@ class DiffusionPipeline(ConfigMixin):
            # This makes sure that the weights won't be initialized which significantly speeds up loading.
            if is_diffusers_model or is_transformers_model:
                loading_kwargs["device_map"] = device_map
-               loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+               if from_flax:
+                   loading_kwargs["from_flax"] = True
+
+               # if `from_flax` is set and the model is a transformers model, it cannot
+               # currently be loaded with `low_cpu_mem_usage`
+               if not (from_flax and is_transformers_model):
+                   loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+               else:
+                   loading_kwargs["low_cpu_mem_usage"] = False

            # check if the module is in a subdirectory
            if os.path.isdir(os.path.join(cached_folder, name)):
@@ -47,7 +47,7 @@ from diffusers import (
    logging,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
@@ -816,6 +816,49 @@ class PipelineSlowTests(unittest.TestCase):
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

+   def test_from_flax_from_pt(self):
+       pipe_pt = StableDiffusionPipeline.from_pretrained(
+           "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+       )
+       pipe_pt.to(torch_device)
+
+       if not is_flax_available():
+           raise ImportError("Make sure flax is installed.")
+
+       from diffusers import FlaxStableDiffusionPipeline
+
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           pipe_pt.save_pretrained(tmpdirname)
+
+           pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
+               tmpdirname, safety_checker=None, from_pt=True
+           )
+
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           pipe_flax.save_pretrained(tmpdirname, params=params)
+           pipe_pt_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None, from_flax=True)
+           pipe_pt_2.to(torch_device)
+
+       prompt = "Hello"
+
+       generator = torch.manual_seed(0)
+       image_0 = pipe_pt(
+           [prompt],
+           generator=generator,
+           num_inference_steps=2,
+           output_type="np",
+       ).images[0]
+
+       generator = torch.manual_seed(0)
+       image_1 = pipe_pt_2(
+           [prompt],
+           generator=generator,
+           num_inference_steps=2,
+           output_type="np",
+       ).images[0]
+
+       assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"


@nightly
@require_torch_gpu
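Beyond the pipeline round trip tested above, the converter can also be driven directly at the model level; a sketch, assuming a local folder that holds a config plus a Flax checkpoint for the same architecture (paths hypothetical):

from diffusers import UNet2DConditionModel
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

# Instantiate the PyTorch skeleton from its config, then pull the weights
# out of a Flax .msgpack checkpoint for the same architecture.
config = UNet2DConditionModel.load_config("path/to/unet")  # hypothetical folder
unet = UNet2DConditionModel.from_config(config)
unet = load_flax_checkpoint_in_pytorch_model(unet, "path/to/unet/diffusion_flax_model.msgpack")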