1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Modular] Consolidate load_default_components into load_components (#12217)

* update

* Apply style fixes

* update

* update

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Dhruv Nair
2025-08-28 16:25:02 +02:00
committed by GitHub
parent b2da59b197
commit ba0e732eb0
11 changed files with 59 additions and 66 deletions

View File

@@ -1418,7 +1418,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO:
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
# 2. do we need ConfigSpec? the are basically just key/val kwargs
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components()
class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
Base class for all Modular pipelines.
@@ -1488,7 +1488,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Components with default_creation_method="from_config" are created immediately, its specs are not included
in config dict and will not be saved in `modular_model_index.json`
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
`load_default_components()`/`load_components()`
`load_components()` (with or without specific component names)
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
@@ -1603,20 +1603,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default
return params
def load_default_components(self, **kwargs):
"""
Load from_pretrained components using the loading specs in the config dict.
Args:
**kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
"""
names = [
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
]
self.load_components(names=names, **kwargs)
@classmethod
@validate_hf_hub_args
def from_pretrained(
@@ -1770,8 +1756,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- non from_pretrained components are created during __init__ and registered as the object itself
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
loader.update_components(guider=guider_spec)
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
loader.load_default_components(names=["unet"])
- (from_pretrained) Components are loaded with the `load_components()` method: e.g.
loader.load_components(names=["unet"]) or loader.load_components() to load all default components
Args:
**kwargs: Keyword arguments where keys are component names and values are component objects.
@@ -2097,13 +2083,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
self.register_to_config(**config_to_register)
# YiYi TODO: support map for additional from_pretrained kwargs
# YiYi/Dhruv TODO: consolidate load_components and load_default_components?
def load_components(self, names: Union[List[str], str], **kwargs):
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
"""
Load selected components from specs.
Args:
names: List of component names to load; by default will not load any components
names: List of component names to load. If None, will load all components with
default_creation_method == "from_pretrained". If provided as a list or string, will load only the
specified components.
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
@@ -2111,7 +2098,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
`variant`, `revision`, etc.
"""
if isinstance(names, str):
if names is None:
names = [
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
]
elif isinstance(names, str):
names = [names]
elif not isinstance(names, list):
raise ValueError(f"Invalid type for names: {type(names)}")