1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] Tear apart from_pretrained() of DiffusionPipeline (#8967)

* break from_pretrained part i.

* part ii.

* init_kwargs

* remove _fetch_init_kwargs

* type annotation

* dtyle

* switch to _check_and_update_init_kwargs_for_missing_modules.

* remove _check_and_update_init_kwargs_for_missing_modules.

* use pipeline_loading_kwargs.

* remove _determine_current_device_map.

* remove _filter_null_components.

* device_map fix.

* fix _update_init_kwargs_with_connected_pipeline.

* better handle custom pipeline.

* explain _maybe_raise_warning_for_inpainting.

* add example for model variant.

* fix
This commit is contained in:
Sayak Paul
2024-08-22 06:50:57 +05:30
committed by GitHub
parent 43f1090a0f
commit 32d6492c7b
2 changed files with 126 additions and 92 deletions

View File

@@ -22,7 +22,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from huggingface_hub import model_info
from huggingface_hub import ModelCard, model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
@@ -33,6 +33,7 @@ from ..utils import (
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
is_peft_available,
@@ -746,3 +747,92 @@ def _fetch_class_library_tuple(module):
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
model_variants = {}
if variant is not None:
for folder in os.listdir(folder):
folder_path = os.path.join(folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config
variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
if variant_exists:
model_variants[folder] = variant
return model_variants
def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
custom_class_name = None
if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
os.path.join(folder, f"{config['_class_name'][0]}.py")
):
custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
custom_class_name = config["_class_name"][1]
return custom_pipeline, custom_class_name
def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
pipeline_class = StableDiffusionInpaintPipelineLegacy
deprecation_message = (
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
def _update_init_kwargs_with_connected_pipeline(
init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
) -> dict:
from .pipeline_utils import DiffusionPipeline
modelcard = ModelCard.load(os.path.join(folder, "README.md"))
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
# We don't scheduler argument to match the existing logic:
# https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy()
if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1:
for k in pipeline_loading_kwargs:
if "scheduler" in k:
_ = pipeline_loading_kwargs_cp.pop(k)
def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix
}
connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
}
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
return connected_passed_kwargs
connected_pipes = {
prefix: DiffusionPipeline.from_pretrained(
repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix)
)
for prefix, repo_id in connected_pipes.items()
if repo_id is not None
}
for prefix, connected_pipe in connected_pipes.items():
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
init_kwargs.update(
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
)
return init_kwargs

View File

@@ -75,7 +75,11 @@ from .pipeline_loading_utils import (
_get_custom_pipeline_class,
_get_final_device_map,
_get_pipeline_class,
_identify_model_variants,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
is_safetensors_compatible,
load_sub_model,
maybe_raise_or_warn,
@@ -622,6 +626,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> pipeline.scheduler = scheduler
```
"""
# Copy the kwargs to re-use during loading connected pipeline.
kwargs_copied = kwargs.copy()
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -722,33 +729,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
config_dict.pop("_ignore_files", None)
# 2. Define which model components should load variants
# We retrieve the information by matching whether variant
# model checkpoints exist in the subfolders
model_variants = {}
if variant is not None:
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
)
if variant_exists:
model_variants[folder] = variant
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
# with variant being `"fp16"`.
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
custom_class_name = None
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
):
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
custom_class_name = config_dict["_class_name"][1]
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
)
pipeline_class = _get_pipeline_class(
cls,
config_dict,
config=config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
class_name=custom_class_name,
@@ -760,23 +753,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
# DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
pipeline_class = StableDiffusionInpaintPipelineLegacy
deprecation_message = (
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
_maybe_raise_warning_for_inpainting(
pipeline_class=pipeline_class,
pretrained_model_name_or_path=pretrained_model_name_or_path,
config=config_dict,
)
# 4. Define expected modules given pipeline signature
# and define non-None initialized modules (=`init_kwargs`)
@@ -787,7 +770,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs and make sure that optional component modules are filtered out
@@ -847,6 +829,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 7. Load each module in the pipeline
current_device_map = None
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
if final_device_map is not None and len(final_device_map) > 0:
component_device = final_device_map.get(name, None)
if component_device is not None:
@@ -854,15 +837,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
else:
current_device_map = None
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
# 7.2 Define all importable classes
# 7.3 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
# 7.3 Use passed sub model or load class_name from library_name
# 7.4 Use passed sub model or load class_name from library_name
if name in passed_class_obj:
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
@@ -900,56 +883,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 8. Handle connected pipelines.
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
load_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"torch_dtype": torch_dtype,
"custom_pipeline": custom_pipeline,
"custom_revision": custom_revision,
"provider": provider,
"sess_options": sess_options,
"device_map": device_map,
"max_memory": max_memory,
"offload_folder": offload_folder,
"offload_state_dict": offload_state_dict,
"low_cpu_mem_usage": low_cpu_mem_usage,
"variant": variant,
"use_safetensors": use_safetensors,
}
init_kwargs = _update_init_kwargs_with_connected_pipeline(
init_kwargs=init_kwargs,
passed_pipe_kwargs=passed_pipe_kwargs,
passed_class_objs=passed_class_obj,
folder=cached_folder,
**kwargs_copied,
)
def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
}
connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
}
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
return connected_passed_kwargs
connected_pipes = {
prefix: DiffusionPipeline.from_pretrained(
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
)
for prefix, repo_id in connected_pipes.items()
if repo_id is not None
}
for prefix, connected_pipe in connected_pipes.items():
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
init_kwargs.update(
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
)
# 8. Potentially add passed objects if expected
# 9. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components