mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update from_componeenet, update_component
This commit is contained in:
@@ -47,7 +47,8 @@ from .modular_pipeline_utils import (
|
||||
format_intermediates_short,
|
||||
make_doc_string,
|
||||
)
|
||||
|
||||
from huggingface_hub import create_repo
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
@@ -1665,7 +1666,45 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the loader with a list of component specs and config specs.
|
||||
Initialize a ModularPipeline instance.
|
||||
|
||||
This method sets up the pipeline by:
|
||||
1. creating default pipeline blocks if not provided
|
||||
2. gather component and config specifications based on the pipeline blocks's requirement (e.g. expected_components, expected_configs)
|
||||
3. update the loading specs of from_pretrained components based on the modular_model_index.json file from huggingface hub if `pretrained_model_name_or_path` is provided
|
||||
4. create defaultfrom_config components and register everything
|
||||
|
||||
Args:
|
||||
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
|
||||
default blocks based on the pipeline class name.
|
||||
pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
|
||||
will load component specs (only for from_pretrained components) and config values from the saved modular_model_index.json file.
|
||||
components_manager: Optional ComponentsManager for managing multiple component cross different pipelines and apply offloading strategies.
|
||||
collection: Optional collection name for organizing components in the ComponentsManager.
|
||||
**kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Initialize with custom blocks
|
||||
pipeline = ModularPipeline(blocks=my_custom_blocks)
|
||||
|
||||
# Initialize from pretrained configuration
|
||||
pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline")
|
||||
|
||||
# Initialize with components manager
|
||||
pipeline = ModularPipeline(
|
||||
blocks=my_blocks,
|
||||
components_manager=ComponentsManager(),
|
||||
collection="my_collection"
|
||||
)
|
||||
```
|
||||
|
||||
Notes:
|
||||
- If blocks is None, the method will try to find default blocks based on the pipeline class name
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with `load_default_components()`/`load_components()`
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
|
||||
@@ -1715,6 +1754,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns:
|
||||
- Dictionary mapping input names to their default values
|
||||
"""
|
||||
params = {}
|
||||
for input_param in self.blocks.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
@@ -1722,7 +1765,40 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
|
||||
Execute the pipeline by running the pipeline blocks with the given inputs.
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", "latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Get complete pipeline state
|
||||
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
|
||||
print(state.intermediates) # All intermediate outputs
|
||||
|
||||
# Get specific output
|
||||
image = pipeline(prompt="A beautiful sunset", output="image")
|
||||
|
||||
# Get multiple specific outputs
|
||||
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
|
||||
image, latents = results["image"], results["latents"]
|
||||
|
||||
# Continue from previous state
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. `output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
@@ -1776,6 +1852,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
"""
|
||||
Load from_pretrained components using the loading specs in the config dict.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to `load_components()` method
|
||||
"""
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
@@ -1793,6 +1875,19 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
collection: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load a ModularPipeline from a huggingface hub repo.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
|
||||
Path to a pretrained pipeline configuration. If provided, will load component specs (only for from_pretrained components) and config values from the modular_model_index.json file.
|
||||
trust_remote_code (`bool`, optional):
|
||||
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
||||
components_manager (`ComponentsManager`, optional):
|
||||
ComponentsManager instance for managing multiple component cross different pipelines and apply offloading strategies.
|
||||
collection (`str`, optional):`
|
||||
Collection name for organizing components in the ComponentsManager.
|
||||
"""
|
||||
from ..pipelines.pipeline_loading_utils import _get_pipeline_class
|
||||
|
||||
try:
|
||||
@@ -1830,14 +1925,50 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return pipeline
|
||||
|
||||
# YiYi TODO:
|
||||
# 1. should support save some components too! currently only modular_model_index.json is saved
|
||||
# 2. maybe order the json file to make it more readable: configs first, then components
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
"""
|
||||
Save the pipeline to a directory. It does not save components, you need to save them separately.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Path to the directory where the pipeline will be saved.
|
||||
push_to_hub (`bool`, optional):
|
||||
Whether to push the pipeline to the huggingface hub.
|
||||
**kwargs: Additional arguments passed to `save_config()` method
|
||||
|
||||
|
||||
"""
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
|
||||
model_card = populate_model_card(model_card)
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
)
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
"""
|
||||
Returns:
|
||||
- The docstring of the pipeline blocks
|
||||
"""
|
||||
return self.blocks.doc
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
@@ -1846,25 +1977,24 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
This method is responsible for:
|
||||
1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
|
||||
2. Updates the modular_model_index.json configuration for serialization (only for from_pretrained components)
|
||||
2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only for from_pretrained components)
|
||||
3. Adds components to the component manager if one is attached (only for from_pretrained components)
|
||||
|
||||
This method is called when:
|
||||
- Components are first initialized in __init__:
|
||||
- from_pretrained components not loaded during __init__ so they are registered as None;
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or
|
||||
loader.update(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"])
|
||||
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
|
||||
loader.update_components(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. loader.load_default_components(names=["unet"])
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
E.g., register_components(unet=unet_model, text_encoder=encoder_model)
|
||||
|
||||
Notes:
|
||||
- Components must be created from ComponentSpec (have _diffusers_load_id attribute)
|
||||
- When registering None for a component, it sets attribute to None but still syncs specs with the
|
||||
modular_model_index.json config
|
||||
- When registering None for a component, it sets attribute to None but still syncs specs with the config dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- component_specs are updated to match the new component outside of this method, e.g. in `update_components()` method
|
||||
"""
|
||||
for name, module in kwargs.items():
|
||||
# current component spec
|
||||
@@ -1877,10 +2007,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
is_registered = hasattr(self, name)
|
||||
is_from_pretrained = component_spec.default_creation_method == "from_pretrained"
|
||||
|
||||
# make sure the component is created from ComponentSpec
|
||||
if module is not None and not hasattr(module, "_diffusers_load_id"):
|
||||
raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.")
|
||||
|
||||
if module is not None:
|
||||
# actual library and class name of the module
|
||||
library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
|
||||
@@ -1906,7 +2032,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if is_from_pretrained:
|
||||
self.register_to_config(**register_dict)
|
||||
setattr(self, name, module)
|
||||
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
|
||||
if module is not None and is_from_pretrained and self._components_manager is not None:
|
||||
self._components_manager.add(name, module, self._collection)
|
||||
continue
|
||||
|
||||
@@ -1942,7 +2068,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
# finally set models
|
||||
setattr(self, name, module)
|
||||
# add to component manager if one is attached
|
||||
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
|
||||
if module is not None and is_from_pretrained and self._components_manager is not None:
|
||||
self._components_manager.add(name, module, self._collection)
|
||||
|
||||
@property
|
||||
@@ -1998,14 +2124,26 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def null_component_names(self) -> List[str]:
|
||||
"""
|
||||
Returns:
|
||||
- List of names for components that needs to be loaded
|
||||
"""
|
||||
return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
|
||||
|
||||
@property
|
||||
def component_names(self) -> List[str]:
|
||||
"""
|
||||
Returns:
|
||||
- List of names for all components
|
||||
"""
|
||||
return list(self.components.keys())
|
||||
|
||||
@property
|
||||
def pretrained_component_names(self) -> List[str]:
|
||||
"""
|
||||
Returns:
|
||||
- List of names for from_pretrained components
|
||||
"""
|
||||
return [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
@@ -2014,6 +2152,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def config_component_names(self) -> List[str]:
|
||||
"""
|
||||
Returns:
|
||||
- List of names for from_config components
|
||||
"""
|
||||
return [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
@@ -2022,44 +2164,60 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns:
|
||||
- Dictionary mapping component names to their objects (include both from_pretrained and from_config components)
|
||||
"""
|
||||
# return only components we've actually set as attributes on self
|
||||
return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
|
||||
|
||||
def get_component_spec(self, name: str) -> ComponentSpec:
|
||||
"""
|
||||
Returns:
|
||||
- a copy of the ComponentSpec object for the given component name
|
||||
"""
|
||||
return deepcopy(self._component_specs[name])
|
||||
|
||||
def update_components(self, **kwargs):
|
||||
"""
|
||||
Update components and configuration values after the loader has been instantiated.
|
||||
Update components and configuration values and specs after the pipeline has been instantiated.
|
||||
|
||||
This method allows you to:
|
||||
1. Replace existing components with new ones (e.g., updating the unet or text_encoder)
|
||||
2. Update configuration values (e.g., changing requires_safety_checker flag)
|
||||
1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
|
||||
2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
|
||||
|
||||
In addition to updating the components and configuration values as pipeline attributes, the method also updates:
|
||||
- the corresponding specs in `_component_specs` and `_config_specs`
|
||||
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
|
||||
Args:
|
||||
**kwargs: Component objects or configuration values to update:
|
||||
- Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet,
|
||||
text_encoder=new_encoder`)
|
||||
- Configuration values: Simple values to update configuration settings (e.g.,
|
||||
`requires_safety_checker=False`)
|
||||
- ComponentSpec objects: if passed a ComponentSpec object, only support from_config spec, will call
|
||||
create() method to create it
|
||||
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
|
||||
- Component objects: Only supports components we can extract specs using `ComponentSpec.from_component()` method
|
||||
i.e. components created with ComponentSpec.load() or ConfigMixin subclasses that aren't nn.Modules
|
||||
(e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create() method to create a new component
|
||||
(e.g., `guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
|
||||
- Configuration values: Simple values to update configuration settings
|
||||
(e.g., `requires_safety_checker=False`)
|
||||
|
||||
Raises:
|
||||
ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute)
|
||||
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
|
||||
- nn.Module components without a valid `_diffusers_load_id` attribute
|
||||
- Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Update multiple components at once
|
||||
loader.update(unet=new_unet_model, text_encoder=new_text_encoder)
|
||||
pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
|
||||
|
||||
# Update configuration values
|
||||
loader.update(requires_safety_checker=False)
|
||||
pipeline.update_components(requires_safety_checker=False)
|
||||
|
||||
# Update both components and configs together
|
||||
loader.update(unet=new_unet_model, requires_safety_checker=False)
|
||||
# update with ComponentSpec objects
|
||||
loader.update(
|
||||
pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
|
||||
|
||||
# Update with ComponentSpec objects (from_config only)
|
||||
pipeline.update_components(
|
||||
guider=ComponentSpec(
|
||||
name="guider",
|
||||
type_hint=ClassifierFreeGuidance,
|
||||
@@ -2068,6 +2226,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
|
||||
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in update_components()
|
||||
"""
|
||||
|
||||
# extract component_specs_updates & config_specs_updates from `specs`
|
||||
@@ -2080,28 +2243,23 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
|
||||
|
||||
for name, component in passed_components.items():
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.")
|
||||
|
||||
# YiYi TODO: remove this if we remove support for non config mixin components in `create()` method
|
||||
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
|
||||
raise ValueError(
|
||||
f"The passed component '{name}' is not supported in update() method "
|
||||
f"because it is not supported in `ComponentSpec.from_component()`. "
|
||||
f"Please pass a ComponentSpec object instead."
|
||||
)
|
||||
current_component_spec = self._component_specs[name]
|
||||
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(
|
||||
component, current_component_spec.type_hint
|
||||
):
|
||||
logger.warning(
|
||||
f"ModularPipeline.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
)
|
||||
# update _component_specs based on the new component
|
||||
new_component_spec = ComponentSpec.from_component(name, component)
|
||||
if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
|
||||
logger.warning(f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}.")
|
||||
|
||||
self._component_specs[name] = new_component_spec
|
||||
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
|
||||
@@ -2109,7 +2267,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
for name, component_spec in passed_component_specs.items():
|
||||
if component_spec.default_creation_method == "from_pretrained":
|
||||
raise ValueError(
|
||||
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method"
|
||||
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
|
||||
)
|
||||
created_components[name] = component_spec.create()
|
||||
current_component_spec = self._component_specs[name]
|
||||
@@ -2118,7 +2276,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
created_components[name], current_component_spec.type_hint
|
||||
):
|
||||
logger.warning(
|
||||
f"ModularPipeline.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
)
|
||||
# update _component_specs based on the user passed component_spec
|
||||
self._component_specs[name] = component_spec
|
||||
@@ -2145,7 +2303,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
|
||||
`variant`, `revision`, etc.
|
||||
"""
|
||||
# if not pass any names, will not load any components
|
||||
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
elif not isinstance(names, list):
|
||||
@@ -2393,8 +2551,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
@staticmethod
|
||||
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
|
||||
"""
|
||||
Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. If the
|
||||
default_creation_method is not from_pretrained, return None.
|
||||
Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`.
|
||||
If the `default_creation_method` is not `from_pretrained`, return None.
|
||||
|
||||
This dict contains:
|
||||
- "type_hint": Tuple[str, str]
|
||||
|
||||
@@ -19,12 +19,13 @@ from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..utils.import_utils import is_torch_available
|
||||
|
||||
from ..utils import is_torch_available, logging
|
||||
import torch
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
class InsertableDict(OrderedDict):
|
||||
def insert(self, key, value, index):
|
||||
@@ -110,28 +111,58 @@ class ComponentSpec:
|
||||
|
||||
@classmethod
|
||||
def from_component(cls, name: str, component: Any) -> Any:
|
||||
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
|
||||
"""Create a ComponentSpec from a Component.
|
||||
|
||||
Currently supports:
|
||||
- Components created with `ComponentSpec.load()` method
|
||||
- Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
|
||||
|
||||
Args:
|
||||
name: Name of the component
|
||||
component: Component object to create spec from
|
||||
|
||||
Returns:
|
||||
ComponentSpec object
|
||||
|
||||
Raises:
|
||||
ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
|
||||
"""
|
||||
|
||||
# Check if component was created with ComponentSpec.load()
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
||||
# component has a usable load_id -> from_pretrained, no warning needed
|
||||
default_creation_method = "from_pretrained"
|
||||
else:
|
||||
# Component doesn't have a usable load_id, check if it's a nn.Module
|
||||
if isinstance(component, torch.nn.Module):
|
||||
raise ValueError(
|
||||
"Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
|
||||
)
|
||||
# ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
|
||||
elif isinstance(component, ConfigMixin):
|
||||
# warn if component was not created with `ComponentSpec`
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
logger.warning("Component was not created using `ComponentSpec`, defaulting to `from_config` creation method")
|
||||
default_creation_method = "from_config"
|
||||
else:
|
||||
# Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
|
||||
raise ValueError(
|
||||
f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
|
||||
f" - components created with `ComponentSpec.load()` method"
|
||||
f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
|
||||
)
|
||||
|
||||
if not hasattr(component, "_diffusers_load_id"):
|
||||
raise ValueError("Component is not created by `create` or `load` method")
|
||||
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
|
||||
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
|
||||
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
|
||||
raise ValueError(
|
||||
"We currently only support creating ComponentSpec from a component with "
|
||||
"created with `ComponentSpec.load` method"
|
||||
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
|
||||
)
|
||||
|
||||
type_hint = component.__class__
|
||||
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
|
||||
|
||||
if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
|
||||
config = component.config
|
||||
else:
|
||||
config = None
|
||||
|
||||
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
||||
load_spec = cls.decode_load_id(component._diffusers_load_id)
|
||||
else:
|
||||
load_spec = {}
|
||||
|
||||
return cls(
|
||||
name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
|
||||
|
||||
Reference in New Issue
Block a user