mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Modular] Consolidate load_default_components into load_components (#12217)
* update * Apply style fixes * update * update --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1418,7 +1418,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO:
|
||||
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
||||
# 2. do we need ConfigSpec? the are basically just key/val kwargs
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components()
|
||||
class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
Base class for all Modular pipelines.
|
||||
@@ -1488,7 +1488,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included
|
||||
in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
|
||||
`load_default_components()`/`load_components()`
|
||||
`load_components()` (with or without specific component names)
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
|
||||
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
@@ -1603,20 +1603,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
"""
|
||||
Load from_pretrained components using the loading specs in the config dict.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
|
||||
"""
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
self.load_components(names=names, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -1770,8 +1756,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
|
||||
loader.update_components(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
|
||||
loader.load_default_components(names=["unet"])
|
||||
- (from_pretrained) Components are loaded with the `load_components()` method: e.g.
|
||||
loader.load_components(names=["unet"]) or loader.load_components() to load all default components
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
@@ -2097,13 +2083,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
# YiYi/Dhruv TODO: consolidate load_components and load_default_components?
|
||||
def load_components(self, names: Union[List[str], str], **kwargs):
|
||||
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
|
||||
"""
|
||||
Load selected components from specs.
|
||||
|
||||
Args:
|
||||
names: List of component names to load; by default will not load any components
|
||||
names: List of component names to load. If None, will load all components with
|
||||
default_creation_method == "from_pretrained". If provided as a list or string, will load only the
|
||||
specified components.
|
||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
@@ -2111,7 +2098,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
`variant`, `revision`, etc.
|
||||
"""
|
||||
|
||||
if isinstance(names, str):
|
||||
if names is None:
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
elif not isinstance(names, list):
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
Reference in New Issue
Block a user