diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 5cf471314d..f9a039ddaa 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -241,19 +241,6 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False - - def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs): - module_class = spec.type_hint - - - if spec.revision is not None: - kwargs["revision"] = spec.revision - if spec.variant is not None: - kwargs["variant"] = spec.variant - - component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs) - return component - def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") @@ -284,21 +271,23 @@ class ComponentsManager: **kwargs: Additional arguments to pass to the component loader """ - if isinstance(spec, ComponentSpec): - if spec.config is not None: - component = spec.type_hint(**spec.config) - self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec)) - return - - spec = ComponentLoadSpec.from_component_spec(spec) - - for k, v in self.components_specs.items(): - if v == spec and not force_add: - logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.") - return + if isinstance(spec, ComponentSpec) and spec.repo is None: + component = spec.create(**kwargs) + self.add(name, component, collection=collection) + elif isinstance(spec, ComponentSpec): + load_spec = spec.to_load_spec() + elif isinstance(spec, ComponentLoadSpec): + load_spec = spec + else: + raise ValueError(f"Invalid spec type: {type(spec)}") - component = self.load_component(spec, **kwargs) - self.add(name, component, collection=collection, load_spec=spec) + for k, v in self.components_specs.items(): + if v == load_spec and not force_add: + logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.") + return + + component = load_spec.load(**kwargs) + self.add(name, component, collection=collection, load_spec=load_spec) def remove(self, name): diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index 282b94bb08..bb8cc1283e 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..utils.import_utils import is_torch_available @@ -27,26 +27,53 @@ if is_torch_available(): class ComponentSpec: """Specification for a pipeline component.""" name: str - # YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance - type_hint: Type + type_hint: Type # YiYi Notes: change to component_type? description: Optional[str] = None config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor repo: Optional[Union[str, List[str]]] = None subfolder: Optional[str] = None - revision: Optional[str] = None - variant: Optional[str] = None + + def create(self, **kwargs) -> Any: + """ + Create the component based on the config and additional kwargs. + + Args: + **kwargs: Additional arguments to pass to the component's __init__ method + + Returns: + The created component + """ + if self.config is not None: + init_kwargs = self.config + else: + init_kwargs = {} + return self.type_hint(**init_kwargs, **kwargs) + + def load(self, **kwargs) -> Any: + return self.to_load_spec().load(**kwargs) + + def to_load_spec(self) -> "ComponentLoadSpec": + """Convert to a ComponentLoadSpec for storage in ComponentsManager.""" + return ComponentLoadSpec.from_component_spec(self) @dataclass class ComponentLoadSpec: type_hint: type repo: Optional[str] = None subfolder: Optional[str] = None - revision: Optional[str] = None - variant: Optional[str] = None + def load(self, **kwargs) -> Any: + """Load the component from the repository.""" + repo = kwargs.pop("repo", self.repo) + subfolder = kwargs.pop("subfolder", self.subfolder) + + return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs) + + @classmethod def from_component_spec(cls, component_spec: ComponentSpec): - return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant) + return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder) + @dataclass class ConfigSpec: @@ -54,7 +81,7 @@ class ConfigSpec: name: str value: Any description: Optional[str] = None - repo: Optional[Union[str, List[str]]] = None + repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed @dataclass class InputParam: