1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

update from_componeenet, update_component

This commit is contained in:
yiyixuxu
2025-07-07 09:51:04 +02:00
parent 179d6d958b
commit 5af003a9e1
2 changed files with 257 additions and 68 deletions

View File

@@ -47,7 +47,8 @@ from .modular_pipeline_utils import (
format_intermediates_short,
make_doc_string,
)
from huggingface_hub import create_repo
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
if is_accelerate_available():
import accelerate
@@ -1665,7 +1666,45 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
**kwargs,
):
"""
Initialize the loader with a list of component specs and config specs.
Initialize a ModularPipeline instance.
This method sets up the pipeline by:
1. creating default pipeline blocks if not provided
2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs)
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided
4. create defaultfrom_config components and register everything
Args:
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
default blocks based on the pipeline class name.
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file.
components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies.
collection: Optional collection name for organizing components in the ComponentsManager.
**kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
Examples:
```python
# Initialize with custom blocks
pipeline = ModularPipeline(blocks=my_custom_blocks)
# Initialize from pretrained configuration
pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline")
# Initialize with components manager
pipeline = ModularPipeline(
blocks=my_blocks,
components_manager=ComponentsManager(),
collection="my_collection"
)
```
Notes:
- If blocks is None, the method will try to find default blocks based on the pipeline class name
- Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json`
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()`
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained`
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict
"""
if blocks is None:
blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
@@ -1715,6 +1754,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
@property
def default_call_parameters(self) -> Dict[str, Any]:
"""
Returns:
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
@@ -1722,7 +1765,40 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
"""
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
Execute the pipeline by running the pipeline blocks with the given inputs.
Args:
state (`PipelineState`, optional):
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement.
output (`str` or `List[str]`, optional):
Optional specification of what to return:
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`)
Examples:
```python
# Get complete pipeline state
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
print(state.intermediates) # All intermediate outputs
# Get specific output
image = pipeline(prompt="A beautiful sunset", output="image")
# Get multiple specific outputs
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
image, latents = results["image"], results["latents"]
# Continue from previous state
state = pipeline(prompt="A beautiful sunset")
new_state = pipeline(state=state, output="image") # Continue processing
```
Returns:
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`)
"""
if state is None:
state = PipelineState()
@@ -1776,6 +1852,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
raise ValueError(f"Output '{output}' is not a valid output type")
def load_default_components(self, **kwargs):
"""
Load from_pretrained components using the loading specs in the config dict.
Args:
**kwargs: Additional arguments passed to `load_components()` method
"""
names = [
name
for name in self._component_specs.keys()
@@ -1793,6 +1875,19 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
collection: Optional[str] = None,
**kwargs,
):
"""
Load a ModularPipeline from a huggingface hub repo.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file.
trust_remote_code (`bool`, optional):
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path`
components_manager (`ComponentsManager`, optional):
ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies.
collection (`str`, optional):`
Collection name for organizing components in the ComponentsManager.
"""
from ..pipelines.pipeline_loading_utils import _get_pipeline_class
try:
@@ -1830,14 +1925,50 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
# YiYi TODO:
# 1. should support save some components too! currently only modular_model_index.json is saved
# 2. maybe order the json file to make it more readable: configs first, then components
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
"""
Save the pipeline to a directory. It does not save components, you need to save them separately.
Args:
save_directory (`str` or `os.PathLike`):
Path to the directory where the pipeline will be saved.
push_to_hub (`bool`, optional):
Whether to push the pipeline to the huggingface hub.
**kwargs: Additional arguments passed to `save_config()` method
"""
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
self.save_config(save_directory=save_directory)
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
@property
def doc(self):
"""
Returns:
- The docstring of the pipeline blocks
"""
return self.blocks.doc
def register_components(self, **kwargs):
@@ -1846,25 +1977,24 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
This method is responsible for:
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
2. Updates the modular_model_index.json configuration for serialization (only for from_pretrained components)
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components)
3. Adds components to the component manager if one is attached (only for from_pretrained components)
This method is called when:
- Components are first initialized in __init__:
- from_pretrained components not loaded during __init__ so they are registered as None;
- non from_pretrained components are created during __init__ and registered as the object itself
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or
loader.update(guider=guider_spec)
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"])
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
loader.update_components(guider=guider_spec)
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"])
Args:
**kwargs: Keyword arguments where keys are component names and values are component objects.
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
Notes:
- Components must be created from ComponentSpec (have _diffusers_load_id attribute)
- When registering None for a component, it sets attribute to None but still syncs specs with the
modular_model_index.json config
- When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained`
- component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method
"""
for name, module in kwargs.items():
# current component spec
@@ -1877,10 +2007,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
is_registered = hasattr(self, name)
is_from_pretrained = component_spec.default_creation_method == "from_pretrained"
# make sure the component is created from ComponentSpec
if module is not None and not hasattr(module, "_diffusers_load_id"):
raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.")
if module is not None:
# actual library and class name of the module
library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
@@ -1906,7 +2032,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if is_from_pretrained:
self.register_to_config(**register_dict)
setattr(self, name, module)
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
if module is not None and is_from_pretrained and self._components_manager is not None:
self._components_manager.add(name, module, self._collection)
continue
@@ -1942,7 +2068,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# finally set models
setattr(self, name, module)
# add to component manager if one is attached
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
if module is not None and is_from_pretrained and self._components_manager is not None:
self._components_manager.add(name, module, self._collection)
@property
@@ -1998,14 +2124,26 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
@property
def null_component_names(self) -> List[str]:
"""
Returns:
- List of names for components that needs to be loaded
"""
return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
@property
def component_names(self) -> List[str]:
"""
Returns:
- List of names for all components
"""
return list(self.components.keys())
@property
def pretrained_component_names(self) -> List[str]:
"""
Returns:
- List of names for from_pretrained components
"""
return [
name
for name in self._component_specs.keys()
@@ -2014,6 +2152,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
@property
def config_component_names(self) -> List[str]:
"""
Returns:
- List of names for from_config components
"""
return [
name
for name in self._component_specs.keys()
@@ -2022,44 +2164,60 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
@property
def components(self) -> Dict[str, Any]:
"""
Returns:
- Dictionary mapping component names to their objects (include both from_pretrained and from_config components)
"""
# return only components we've actually set as attributes on self
return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
def get_component_spec(self, name: str) -> ComponentSpec:
"""
Returns:
- a copy of the ComponentSpec object for the given component name
"""
return deepcopy(self._component_specs[name])
def update_components(self, **kwargs):
"""
Update components and configuration values after the loader has been instantiated.
Update components and configuration values and specs after the pipeline has been instantiated.
This method allows you to:
1. Replace existing components with new ones (e.g., updating the unet or text_encoder)
2. Update configuration values (e.g., changing requires_safety_checker flag)
1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
In addition to updating the components and configuration values as pipeline attributes, the method also updates:
- the corresponding specs in `_component_specs` and `_config_specs`
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
Args:
**kwargs: Component objects or configuration values to update:
- Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet,
text_encoder=new_encoder`)
- Configuration values: Simple values to update configuration settings (e.g.,
`requires_safety_checker=False`)
- ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call
create() method to create it
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
- Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method
i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules
(e.g., `unet=new_unet, text_encoder=new_encoder`)
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component
(e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
- Configuration values: Simple values to update configuration settings
(e.g., `requires_safety_checker=False`)
Raises:
ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute)
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
- nn.Module components without a valid `_diffusers_load_id` attribute
- Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
Examples:
```python
# Update multiple components at once
loader.update(unet=new_unet_model, text_encoder=new_text_encoder)
pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
# Update configuration values
loader.update(requires_safety_checker=False)
pipeline.update_components(requires_safety_checker=False)
# Update both components and configs together
loader.update(unet=new_unet_model, requires_safety_checker=False)
# update with ComponentSpec objects
loader.update(
pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
# Update with ComponentSpec objects (from_config only)
pipeline.update_components(
guider=ComponentSpec(
name="guider",
type_hint=ClassifierFreeGuidance,
@@ -2068,6 +2226,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
)
```
Notes:
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components()
"""
# extract component_specs_updates & config_specs_updates from `specs`
@@ -2080,28 +2243,23 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
for name, component in passed_components.items():
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.")
# YiYi TODO: remove this if we remove support for non config mixin components in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
f"The passed component '{name}' is not supported in update() method "
f"because it is not supported in `ComponentSpec.from_component()`. "
f"Please pass a ComponentSpec object instead."
)
current_component_spec = self._component_specs[name]
# warn if type changed
if current_component_spec.type_hint is not None and not isinstance(
component, current_component_spec.type_hint
):
logger.warning(
f"ModularPipeline.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
)
# update _component_specs based on the new component
new_component_spec = ComponentSpec.from_component(name, component)
if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.")
self._component_specs[name] = new_component_spec
if len(kwargs) > 0:
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
@@ -2109,7 +2267,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, component_spec in passed_component_specs.items():
if component_spec.default_creation_method == "from_pretrained":
raise ValueError(
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method"
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
)
created_components[name] = component_spec.create()
current_component_spec = self._component_specs[name]
@@ -2118,7 +2276,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
created_components[name], current_component_spec.type_hint
):
logger.warning(
f"ModularPipeline.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
)
# update _component_specs based on the user passed component_spec
self._component_specs[name] = component_spec
@@ -2145,7 +2303,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
`variant`, `revision`, etc.
"""
# if not pass any names, will not load any components
if isinstance(names, str):
names = [names]
elif not isinstance(names, list):
@@ -2393,8 +2551,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
@staticmethod
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
"""
Convert a ComponentSpec into a JSONserializable dict for saving in `modular_model_index.json`. If the
default_creation_method is not from_pretrained, return None.
Convert a ComponentSpec into a JSONserializable dict for saving as an entry in `modular_model_index.json`.
If the `default_creation_method` is not `from_pretrained`, return None.
This dict contains:
- "type_hint": Tuple[str, str]

View File

@@ -19,12 +19,13 @@ from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
from ..configuration_utils import ConfigMixin, FrozenDict
from ..utils.import_utils import is_torch_available
from ..utils import is_torch_available, logging
import torch
if is_torch_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class InsertableDict(OrderedDict):
def insert(self, key, value, index):
@@ -110,28 +111,58 @@ class ComponentSpec:
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
"""Create a ComponentSpec from a Component.
Currently supports:
- Components created with `ComponentSpec.load()` method
- Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
Args:
name: Name of the component
component: Component object to create spec from
Returns:
ComponentSpec object
Raises:
ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
"""
# Check if component was created with ComponentSpec.load()
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
# component has a usable load_id -> from_pretrained, no warning needed
default_creation_method = "from_pretrained"
else:
# Component doesn't have a usable load_id, check if it's a nn.Module
if isinstance(component, torch.nn.Module):
raise ValueError(
"Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
)
# ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
elif isinstance(component, ConfigMixin):
# warn if component was not created with `ComponentSpec`
if not hasattr(component, "_diffusers_load_id"):
logger.warning("Component was not created using `ComponentSpec`, defaulting to `from_config` creation method")
default_creation_method = "from_config"
else:
# Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
raise ValueError(
f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
f" - components created with `ComponentSpec.load()` method"
f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
)
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` or `load` method")
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
"We currently only support creating ComponentSpec from a component with "
"created with `ComponentSpec.load` method"
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
)
type_hint = component.__class__
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
load_spec = cls.decode_load_id(component._diffusers_load_id)
else:
load_spec = {}
return cls(
name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec