mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[refactor] DiffusionPipeline.download (#9557)
* update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -838,3 +838,108 @@ def _update_init_kwargs_with_connected_pipeline(
|
||||
)
|
||||
|
||||
return init_kwargs
|
||||
|
||||
|
||||
def _get_custom_components_and_folders(
|
||||
pretrained_model_name: str,
|
||||
config_dict: Dict[str, Any],
|
||||
filenames: Optional[List[str]] = None,
|
||||
variant_filenames: Optional[List[str]] = None,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
config_dict = config_dict.copy()
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# optionally create a custom component <> custom file mapping
|
||||
custom_components = {}
|
||||
for component in folder_names:
|
||||
module_candidate = config_dict[component][0]
|
||||
|
||||
if module_candidate is None or not isinstance(module_candidate, str):
|
||||
continue
|
||||
|
||||
# We compute candidate file path on the Hub. Do not use `os.path.join`.
|
||||
candidate_file = f"{component}/{module_candidate}.py"
|
||||
|
||||
if candidate_file in filenames:
|
||||
custom_components[component] = module_candidate
|
||||
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
||||
raise ValueError(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
return custom_components, folder_names
|
||||
|
||||
|
||||
def _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names: List[str],
|
||||
model_filenames: List[str],
|
||||
variant_filenames: List[str],
|
||||
use_safetensors: bool,
|
||||
from_flax: bool,
|
||||
allow_pickle: bool,
|
||||
use_onnx: bool,
|
||||
is_onnx: bool,
|
||||
variant: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
if (
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
|
||||
)
|
||||
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
|
||||
elif use_safetensors and is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
|
||||
f"expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
|
||||
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
|
||||
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
|
||||
f"your folder structure."
|
||||
)
|
||||
|
||||
return ignore_patterns
|
||||
|
||||
@@ -71,15 +71,16 @@ from .pipeline_loading_utils import (
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_components_and_folders,
|
||||
_get_custom_pipeline_class,
|
||||
_get_final_device_map,
|
||||
_get_ignore_patterns,
|
||||
_get_pipeline_class,
|
||||
_identify_model_variants,
|
||||
_maybe_raise_warning_for_inpainting,
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
_unwrap_model,
|
||||
_update_init_kwargs_with_connected_pipeline,
|
||||
is_safetensors_compatible,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
variant_compatible_siblings,
|
||||
@@ -1298,44 +1299,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# optionally create a custom component <> custom file mapping
|
||||
custom_components = {}
|
||||
for component in folder_names:
|
||||
module_candidate = config_dict[component][0]
|
||||
|
||||
if module_candidate is None or not isinstance(module_candidate, str):
|
||||
continue
|
||||
|
||||
# We compute candidate file path on the Hub. Do not use `os.path.join`.
|
||||
candidate_file = f"{component}/{module_candidate}.py"
|
||||
|
||||
if candidate_file in filenames:
|
||||
custom_components[component] = module_candidate
|
||||
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
||||
raise ValueError(
|
||||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
||||
)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
raise ValueError(error_message)
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
# if the whole pipeline is cached we don't have to ping the Hub
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.22.0"):
|
||||
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
||||
|
||||
custom_components, folder_names = _get_custom_components_and_folders(
|
||||
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
||||
)
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
@@ -1395,49 +1370,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
||||
passed_components = [k for k in expected_components if k in kwargs]
|
||||
|
||||
if (
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
|
||||
)
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif use_safetensors and is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if (
|
||||
len(safetensors_variant_filenames) > 0
|
||||
and safetensors_model_filenames != safetensors_variant_filenames
|
||||
):
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
# retrieve all patterns that should not be downloaded and error out when needed
|
||||
ignore_patterns = _get_ignore_patterns(
|
||||
passed_components,
|
||||
model_folder_names,
|
||||
model_filenames,
|
||||
variant_filenames,
|
||||
use_safetensors,
|
||||
from_flax,
|
||||
allow_pickle,
|
||||
use_onnx,
|
||||
pipeline_class._is_onnx,
|
||||
variant,
|
||||
)
|
||||
|
||||
# Don't download any objects that are passed
|
||||
allow_patterns = [
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
|
||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user