mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -1409,7 +1409,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO:
|
||||
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
||||
# 2. do we need ConfigSpec? the are basically just key/val kwargs
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components()
|
||||
class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
Base class for all Modular pipelines.
|
||||
@@ -1478,7 +1478,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included
|
||||
in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
|
||||
`load_default_components()`/`load_components()`
|
||||
`load_components()` (with or without specific component names)
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
|
||||
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
@@ -1548,12 +1548,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
|
||||
"""
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
self.load_components(names=names, **kwargs)
|
||||
# Consolidated into load_components - just call it without names parameter
|
||||
logger.warning("`load_default_components` is deprecated. Please use `load_components()` instead")
|
||||
self.load_components(**kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
@@ -1682,8 +1679,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
|
||||
loader.update_components(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
|
||||
loader.load_default_components(names=["unet"])
|
||||
- (from_pretrained) Components are loaded with the `load_components()` method: e.g.
|
||||
loader.load_components(names=["unet"]) or loader.load_components() to load all default components
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
@@ -1995,13 +1992,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
# YiYi/Dhruv TODO: consolidate load_components and load_default_components?
|
||||
def load_components(self, names: Union[List[str], str], **kwargs):
|
||||
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
|
||||
"""
|
||||
Load selected components from specs.
|
||||
|
||||
Args:
|
||||
names: List of component names to load; by default will not load any components
|
||||
names: List of component names to load. If None, will load all components with
|
||||
default_creation_method == "from_pretrained". If provided as a list or string,
|
||||
will load only the specified components.
|
||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
@@ -2009,7 +2007,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
`variant`, `revision`, etc.
|
||||
"""
|
||||
|
||||
if isinstance(names, str):
|
||||
if names is None:
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
elif not isinstance(names, list):
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
Reference in New Issue
Block a user