From f73ed179610653bf100215a54ca2c8a3cba91cf0 Mon Sep 17 00:00:00 2001
From: camenduru <54370274+camenduru@users.noreply.github.com>
Date: Thu, 12 Jan 2023 22:00:35 +0300
Subject: [PATCH] Allow converting Flax to PyTorch by adding a "from_flax"
 keyword (#1900)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* from_flax

* oops

* oops

* make style with pip install -e ".[dev]"

* oops

* now code quality happy 😋

* allow_patterns += FLAX_WEIGHTS_NAME

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Patrick von Platen

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen

* for test

* bye bye is_flax_available()

* oops

* Update src/diffusers/models/modeling_pytorch_flax_utils.py

Co-authored-by: Pedro Cuenca

* Update src/diffusers/models/modeling_pytorch_flax_utils.py

Co-authored-by: Pedro Cuenca

* Update src/diffusers/models/modeling_pytorch_flax_utils.py

Co-authored-by: Pedro Cuenca

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Pedro Cuenca

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Pedro Cuenca

* make style

* add test

* finish

Co-authored-by: Patrick von Platen
Co-authored-by: Pedro Cuenca
---
 .../models/modeling_pytorch_flax_utils.py | 156 ++++++++++++++
 src/diffusers/models/modeling_utils.py    | 195 +++++++++++-------
 src/diffusers/pipelines/pipeline_utils.py |  28 ++-
 tests/test_pipelines.py                   |  45 +++-
 4 files changed, 345 insertions(+), 79 deletions(-)
 create mode 100644 src/diffusers/models/modeling_pytorch_flax_utils.py
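Before the diff itself, a quick orientation: the headline change is a `from_flax` keyword on `from_pretrained`. A minimal usage sketch (the repo id is illustrative; any checkpoint that ships Flax `.msgpack` weights alongside its configs would do):

    import torch
    from diffusers import StableDiffusionPipeline

    # Downloads only the Flax weights and converts them to PyTorch
    # tensors while the pipeline is assembled.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative repo id
        from_flax=True,
    )
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")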
+""" PyTorch - Flax general utilities.""" + +from pickle import UnpicklingError + +import numpy as np + +import jax +import jax.numpy as jnp +from flax.serialization import from_bytes +from flax.traverse_util import flatten_dict + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +##################### +# Flax => PyTorch # +##################### + + +# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 +def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): + try: + with open(model_file, "rb") as flax_state_f: + flax_state = from_bytes(None, flax_state_f.read()) + except UnpicklingError as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + + return load_flax_weights_in_pytorch_model(pt_model, flax_state) + + +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load flax checkpoints in a PyTorch model""" + + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + + # check if we have bf16 weights + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + if any(is_type_bf16): + # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 + + # and bf16 is not fully supported in PT yet. + logger.warning( + "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " + "before loading those in PyTorch model." + ) + flax_state = jax.tree_util.tree_map( + lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state + ) + + pt_model.base_model_prefix = "" + + flax_state_dict = flatten_dict(flax_state, sep=".") + pt_model_dict = pt_model.state_dict() + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + flax_key_tuple_array = flax_key_tuple.split(".") + + if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) + elif flax_key_tuple_array[-1] == "kernel": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + flax_tensor = flax_tensor.T + elif flax_key_tuple_array[-1] == "scale": + flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] + + if "time_embedding" not in flax_key_tuple_array: + for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): + flax_key_tuple_array[i] = ( + flax_key_tuple_string.replace("_0", ".0") + .replace("_1", ".1") + .replace("_2", ".2") + .replace("_3", ".3") + ) + + flax_key = ".".join(flax_key_tuple_array) + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 91c44973b3..26e913c4b9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -30,6 +30,7 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
+    FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     SAFETENSORS_WEIGHTS_NAME,
@@ -335,6 +336,8 @@ class ModelMixin(torch.nn.Module):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                 identifier allowed by git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo (either remote in
                 huggingface.co or downloaded locally), you can specify the folder name here.
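Since `ModelMixin` is the base class for every diffusers model, the keyword documented above works for standalone models too, not just pipelines. A sketch (the local path is illustrative and must contain the Flax weights file named by `FLAX_WEIGHTS_NAME` next to the model config):

    from diffusers import UNet2DConditionModel

    # Builds the PyTorch module from config, then converts the Flax weights.
    unet = UNet2DConditionModel.from_pretrained("./my-flax-unet", from_flax=True)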
@@ -375,6 +378,7 @@ class ModelMixin(torch.nn.Module):
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
+        from_flax = kwargs.pop("from_flax", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
@@ -433,27 +437,10 @@ class ModelMixin(torch.nn.Module):
 
         # Load model
         model_file = None
-        if is_safetensors_available():
-            try:
-                model_file = cls._get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=SAFETENSORS_WEIGHTS_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-            except:
-                pass
-        if model_file is None:
+        if from_flax:
             model_file = cls._get_model_file(
                 pretrained_model_name_or_path,
-                weights_name=WEIGHTS_NAME,
+                weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
                 force_download=force_download,
                 resume_download=resume_download,
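For orientation, the file names behind the constants the hunk above switches between; the exact values are an assumption based on `diffusers.utils` at the time of this patch:

    WEIGHTS_NAME = "diffusion_pytorch_model.bin"
    SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
    FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"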
@@ -464,49 +451,6 @@ class ModelMixin(torch.nn.Module):
                 subfolder=subfolder,
                 user_agent=user_agent,
             )
-
-        if low_cpu_mem_usage:
-            # Instantiate model with empty weights
-            with accelerate.init_empty_weights():
-                config, unused_kwargs = cls.load_config(
-                    config_path,
-                    cache_dir=cache_dir,
-                    return_unused_kwargs=True,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    device_map=device_map,
-                    **kwargs,
-                )
-            model = cls.from_config(config, **unused_kwargs)
-
-            # if device_map is Non,e load the state dict on move the params from meta device to the cpu
-            if device_map is None:
-                param_device = "cpu"
-                state_dict = load_state_dict(model_file)
-                # move the parms from meta device to cpu
-                for param_name, param in state_dict.items():
-                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-                    if accepts_dtype:
-                        set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
-                    else:
-                        set_module_tensor_to_device(model, param_name, param_device, value=param)
-            else:  # else let accelerate handle loading and dispatching.
-                # Load weights and dispatch according to the device_map
-                # by deafult the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
-
-            loading_info = {
-                "missing_keys": [],
-                "unexpected_keys": [],
-                "mismatched_keys": [],
-                "error_msgs": [],
-            }
-        else:
             config, unused_kwargs = cls.load_config(
                 config_path,
                 cache_dir=cache_dir,
@@ -523,22 +467,121 @@ class ModelMixin(torch.nn.Module):
             )
             model = cls.from_config(config, **unused_kwargs)
 
-            state_dict = load_state_dict(model_file)
+            # Convert the weights
+            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
 
-            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-                model,
-                state_dict,
-                model_file,
-                pretrained_model_name_or_path,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-            )
+            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+        else:
+            if is_safetensors_available():
+                try:
+                    model_file = cls._get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=SAFETENSORS_WEIGHTS_NAME,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                except:
+                    pass
+            if model_file is None:
+                model_file = cls._get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
 
-            loading_info = {
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "mismatched_keys": mismatched_keys,
-                "error_msgs": error_msgs,
-            }
+            if low_cpu_mem_usage:
+                # Instantiate model with empty weights
+                with accelerate.init_empty_weights():
+                    config, unused_kwargs = cls.load_config(
+                        config_path,
+                        cache_dir=cache_dir,
+                        return_unused_kwargs=True,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        device_map=device_map,
+                        **kwargs,
+                    )
+                model = cls.from_config(config, **unused_kwargs)
+
+                # if device_map is None, load the state dict and move the params from meta device to the cpu
+                if device_map is None:
+                    param_device = "cpu"
+                    state_dict = load_state_dict(model_file)
+                    # move the params from meta device to cpu
+                    for param_name, param in state_dict.items():
+                        accepts_dtype = "dtype" in set(
+                            inspect.signature(set_module_tensor_to_device).parameters.keys()
+                        )
+                        if accepts_dtype:
+                            set_module_tensor_to_device(
+                                model, param_name, param_device, value=param, dtype=torch_dtype
+                            )
+                        else:
+                            set_module_tensor_to_device(model, param_name, param_device, value=param)
+                else:  # else let accelerate handle loading and dispatching.
+                    # Load weights and dispatch according to the device_map
+                    # by default the device_map is None and the weights are loaded on the CPU
+                    accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
+
+                loading_info = {
+                    "missing_keys": [],
+                    "unexpected_keys": [],
+                    "mismatched_keys": [],
+                    "error_msgs": [],
+                }
+            else:
+                config, unused_kwargs = cls.load_config(
+                    config_path,
+                    cache_dir=cache_dir,
+                    return_unused_kwargs=True,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    device_map=device_map,
+                    **kwargs,
+                )
+                model = cls.from_config(config, **unused_kwargs)
+
+                state_dict = load_state_dict(model_file)
+
+                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                    model,
+                    state_dict,
+                    model_file,
+                    pretrained_model_name_or_path,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                )
+
+                loading_info = {
+                    "missing_keys": missing_keys,
+                    "unexpected_keys": unexpected_keys,
+                    "mismatched_keys": mismatched_keys,
+                    "error_msgs": error_msgs,
+                }
 
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
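One detail of the conversion path the new `from_flax` branch calls into: `torch.from_numpy` cannot ingest bfloat16 arrays, so `load_flax_weights_in_pytorch_model` casts bf16 leaves to float32 first. A self-contained sketch of that cast (requires jax installed):

    import jax
    import jax.numpy as jnp
    import numpy as np

    params = {"dense": {"kernel": jnp.ones((2, 2), dtype=jnp.bfloat16)}}
    # Cast every bf16 leaf to fp32; leave other dtypes untouched.
    params = jax.tree_util.tree_map(
        lambda p: p.astype(np.float32) if p.dtype == jnp.bfloat16 else p, params
    )
    assert params["dense"]["kernel"].dtype == np.float32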
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 854a003e89..ea28ac875f 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -37,6 +37,7 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
+    FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
     ONNX_WEIGHTS_NAME,
     WEIGHTS_NAME,
@@ -445,6 +446,7 @@ class DiffusionPipeline(ConfigMixin):
         local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
+        from_flax = kwargs.pop("from_flax", False)
         torch_dtype = kwargs.pop("torch_dtype", None)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
@@ -470,11 +472,26 @@ class DiffusionPipeline(ConfigMixin):
             # make sure we only download sub-folders and `diffusers` filenames
             folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
             allow_patterns = [os.path.join(k, "*") for k in folder_names]
-            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+            allow_patterns += [
+                WEIGHTS_NAME,
+                SCHEDULER_CONFIG_NAME,
+                CONFIG_NAME,
+                ONNX_WEIGHTS_NAME,
+                cls.config_name,
+            ]
 
             # make sure we don't download flax weights
             ignore_patterns = ["*.msgpack"]
 
+            if from_flax:
+                ignore_patterns = ["*.bin", "*.safetensors"]
+                allow_patterns += [
+                    SCHEDULER_CONFIG_NAME,
+                    CONFIG_NAME,
+                    FLAX_WEIGHTS_NAME,
+                    cls.config_name,
+                ]
+
             if custom_pipeline is not None:
                 allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
 
@@ -704,7 +721,14 @@ class DiffusionPipeline(ConfigMixin):
             # This makes sure that the weights won't be initialized which significantly speeds up loading.
             if is_diffusers_model or is_transformers_model:
                 loading_kwargs["device_map"] = device_map
-                loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+                if from_flax:
+                    loading_kwargs["from_flax"] = True
+
+                # if `from_flax` and the model is a transformers model, it cannot currently be loaded with `low_cpu_mem_usage`
+                if not (from_flax and is_transformers_model):
+                    loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+                else:
+                    loading_kwargs["low_cpu_mem_usage"] = False
 
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
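The pattern changes above decide what gets downloaded: with `from_flax=True` the PyTorch and safetensors weight files are skipped and only the Flax weights (plus configs) are fetched. A stdlib-only sketch of the filtering intent (huggingface_hub's actual matcher may differ in detail):

    from fnmatch import fnmatch

    files = [
        "unet/diffusion_pytorch_model.bin",
        "unet/diffusion_flax_model.msgpack",
        "model_index.json",
    ]
    ignore_patterns = ["*.bin", "*.safetensors"]
    kept = [f for f in files if not any(fnmatch(f, p) for p in ignore_patterns)]
    assert kept == ["unet/diffusion_flax_model.msgpack", "model_index.json"]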
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index b10a145ea3..cf8ef8d6da 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -47,7 +47,7 @@ from diffusers import (
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
 from parameterized import parameterized
 from PIL import Image
@@ -816,6 +816,49 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)
 
+    def test_from_flax_from_pt(self):
+        pipe_pt = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        pipe_pt.to(torch_device)
+
+        if not is_flax_available():
+            raise ImportError("Make sure flax is installed.")
+
+        from diffusers import FlaxStableDiffusionPipeline
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe_pt.save_pretrained(tmpdirname)
+
+            pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
+                tmpdirname, safety_checker=None, from_pt=True
+            )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe_flax.save_pretrained(tmpdirname, params=params)
+            pipe_pt_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None, from_flax=True)
+            pipe_pt_2.to(torch_device)
+
+        prompt = "Hello"
+
+        generator = torch.manual_seed(0)
+        image_0 = pipe_pt(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images[0]
+
+        generator = torch.manual_seed(0)
+        image_1 = pipe_pt_2(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images[0]
+
+        assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"
+
 
 @nightly
 @require_torch_gpu
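A final standalone note on the key renaming inside `load_flax_weights_in_pytorch_model`, which the round-trip test above exercises implicitly: Flax module names such as `down_blocks_0` are rewritten to PyTorch's indexed form `down_blocks.0` (skipped for `time_embedding`, whose trailing digits are part of the real parameter name):

    key = "down_blocks_0.attentions_1.proj.kernel"
    parts = key.split(".")
    if "time_embedding" not in parts:
        parts = [
            p.replace("_0", ".0").replace("_1", ".1").replace("_2", ".2").replace("_3", ".3")
            for p in parts
        ]
    print(".".join(parts))  # down_blocks.0.attentions.1.proj.kernel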