From 32d6492c7bebadca5603f7e8705956af70ef259c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 22 Aug 2024 06:50:57 +0530 Subject: [PATCH] [Core] Tear apart `from_pretrained()` of `DiffusionPipeline` (#8967) * break from_pretrained part i. * part ii. * init_kwargs * remove _fetch_init_kwargs * type annotation * dtyle * switch to _check_and_update_init_kwargs_for_missing_modules. * remove _check_and_update_init_kwargs_for_missing_modules. * use pipeline_loading_kwargs. * remove _determine_current_device_map. * remove _filter_null_components. * device_map fix. * fix _update_init_kwargs_with_connected_pipeline. * better handle custom pipeline. * explain _maybe_raise_warning_for_inpainting. * add example for model variant. * fix --- .../pipelines/pipeline_loading_utils.py | 92 ++++++++++++- src/diffusers/pipelines/pipeline_utils.py | 126 +++++------------- 2 files changed, 126 insertions(+), 92 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a8c23adead..d72292b844 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -22,7 +22,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union import torch -from huggingface_hub import model_info +from huggingface_hub import ModelCard, model_info from huggingface_hub.utils import validate_hf_hub_args from packaging import version @@ -33,6 +33,7 @@ from ..utils import ( ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, + deprecate, get_class_from_dynamic_module, is_accelerate_available, is_peft_available, @@ -746,3 +747,92 @@ def _fetch_class_library_tuple(module): class_name = not_compiled_module.__class__.__name__ return (library, class_name) + + +def _identify_model_variants(folder: str, variant: str, config: dict) -> dict: + model_variants = {} + if variant is not None: + for folder in os.listdir(folder): + folder_path = os.path.join(folder, folder) + is_folder = os.path.isdir(folder_path) and folder in config + variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + if variant_exists: + model_variants[folder] = variant + return model_variants + + +def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline): + custom_class_name = None + if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")): + custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py") + elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile( + os.path.join(folder, f"{config['_class_name'][0]}.py") + ): + custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py") + custom_class_name = config["_class_name"][1] + + return custom_pipeline, custom_class_name + + +def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict): + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + + +def _update_init_kwargs_with_connected_pipeline( + init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs +) -> dict: + from .pipeline_utils import DiffusionPipeline + + modelcard = ModelCard.load(os.path.join(folder, "README.md")) + connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} + + # We don't scheduler argument to match the existing logic: + # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14 + pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy() + if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1: + for k in pipeline_loading_kwargs: + if "scheduler" in k: + _ = pipeline_loading_kwargs_cp.pop(k) + + def get_connected_passed_kwargs(prefix): + connected_passed_class_obj = { + k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix + } + connected_passed_pipe_kwargs = { + k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix + } + + connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} + return connected_passed_kwargs + + connected_pipes = { + prefix: DiffusionPipeline.from_pretrained( + repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix) + ) + for prefix, repo_id in connected_pipes.items() + if repo_id is not None + } + + for prefix, connected_pipe in connected_pipes.items(): + # add connected pipes to `init_kwargs` with _, e.g. "prior_text_encoder" + init_kwargs.update( + {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} + ) + + return init_kwargs diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 631776f250..aa6da17edf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -75,7 +75,11 @@ from .pipeline_loading_utils import ( _get_custom_pipeline_class, _get_final_device_map, _get_pipeline_class, + _identify_model_variants, + _maybe_raise_warning_for_inpainting, + _resolve_custom_pipeline_and_cls, _unwrap_model, + _update_init_kwargs_with_connected_pipeline, is_safetensors_compatible, load_sub_model, maybe_raise_or_warn, @@ -622,6 +626,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): >>> pipeline.scheduler = scheduler ``` """ + # Copy the kwargs to re-use during loading connected pipeline. + kwargs_copied = kwargs.copy() + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -722,33 +729,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): config_dict.pop("_ignore_files", None) # 2. Define which model components should load variants - # We retrieve the information by matching whether variant - # model checkpoints exist in the subfolders - model_variants = {} - if variant is not None: - for folder in os.listdir(cached_folder): - folder_path = os.path.join(cached_folder, folder) - is_folder = os.path.isdir(folder_path) and folder in config_dict - variant_exists = is_folder and any( - p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) - ) - if variant_exists: - model_variants[folder] = variant + # We retrieve the information by matching whether variant model checkpoints exist in the subfolders. + # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` + # with variant being `"fp16"`. + model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it - custom_class_name = None - if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")): - custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py") - elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile( - os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") - ): - custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") - custom_class_name = config_dict["_class_name"][1] - + custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls( + folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline + ) pipeline_class = _get_pipeline_class( cls, - config_dict, + config=config_dict, load_connected_pipeline=load_connected_pipeline, custom_pipeline=custom_pipeline, class_name=custom_class_name, @@ -760,23 +753,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): raise NotImplementedError("`device_map` is not yet supported for connected pipelines.") # DEPRECATED: To be removed in 1.0.0 - if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( - version.parse(config_dict["_diffusers_version"]).base_version - ) <= version.parse("0.5.1"): - from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy - - pipeline_class = StableDiffusionInpaintPipelineLegacy - - deprecation_message = ( - "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" - f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" - " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" - " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" - f" checkpoint {pretrained_model_name_or_path} to the format of" - " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" - " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." - ) - deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded + # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1. + _maybe_raise_warning_for_inpainting( + pipeline_class=pipeline_class, + pretrained_model_name_or_path=pretrained_model_name_or_path, + config=config_dict, + ) # 4. Define expected modules given pipeline signature # and define non-None initialized modules (=`init_kwargs`) @@ -787,7 +770,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) # define init kwargs and make sure that optional component modules are filtered out @@ -847,6 +829,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): # 7. Load each module in the pipeline current_device_map = None for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): + # 7.1 device_map shenanigans if final_device_map is not None and len(final_device_map) > 0: component_device = final_device_map.get(name, None) if component_device is not None: @@ -854,15 +837,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): else: current_device_map = None - # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names class_name = class_name[4:] if class_name.startswith("Flax") else class_name - # 7.2 Define all importable classes + # 7.3 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None - # 7.3 Use passed sub model or load class_name from library_name + # 7.4 Use passed sub model or load class_name from library_name if name in passed_class_obj: # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class @@ -900,56 +883,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + # 8. Handle connected pipelines. if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): - modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) - connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} - load_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "local_files_only": local_files_only, - "token": token, - "revision": revision, - "torch_dtype": torch_dtype, - "custom_pipeline": custom_pipeline, - "custom_revision": custom_revision, - "provider": provider, - "sess_options": sess_options, - "device_map": device_map, - "max_memory": max_memory, - "offload_folder": offload_folder, - "offload_state_dict": offload_state_dict, - "low_cpu_mem_usage": low_cpu_mem_usage, - "variant": variant, - "use_safetensors": use_safetensors, - } + init_kwargs = _update_init_kwargs_with_connected_pipeline( + init_kwargs=init_kwargs, + passed_pipe_kwargs=passed_pipe_kwargs, + passed_class_objs=passed_class_obj, + folder=cached_folder, + **kwargs_copied, + ) - def get_connected_passed_kwargs(prefix): - connected_passed_class_obj = { - k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix - } - connected_passed_pipe_kwargs = { - k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix - } - - connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} - return connected_passed_kwargs - - connected_pipes = { - prefix: DiffusionPipeline.from_pretrained( - repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix) - ) - for prefix, repo_id in connected_pipes.items() - if repo_id is not None - } - - for prefix, connected_pipe in connected_pipes.items(): - # add connected pipes to `init_kwargs` with _, e.g. "prior_text_encoder" - init_kwargs.update( - {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} - ) - - # 8. Potentially add passed objects if expected + # 9. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components