1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

make component spec loadable: add load/create method

This commit is contained in:
yiyixuxu
2025-04-24 12:31:29 +02:00
parent d456a97420
commit a1eb9ee951
2 changed files with 52 additions and 36 deletions

View File

@@ -241,19 +241,6 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs):
module_class = spec.type_hint
if spec.revision is not None:
kwargs["revision"] = spec.revision
if spec.variant is not None:
kwargs["variant"] = spec.variant
component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs)
return component
def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None):
if name in self.components:
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
@@ -284,21 +271,23 @@ class ComponentsManager:
**kwargs: Additional arguments to pass to the component loader
"""
if isinstance(spec, ComponentSpec):
if spec.config is not None:
component = spec.type_hint(**spec.config)
self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec))
return
spec = ComponentLoadSpec.from_component_spec(spec)
for k, v in self.components_specs.items():
if v == spec and not force_add:
logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.")
return
if isinstance(spec, ComponentSpec) and spec.repo is None:
component = spec.create(**kwargs)
self.add(name, component, collection=collection)
elif isinstance(spec, ComponentSpec):
load_spec = spec.to_load_spec()
elif isinstance(spec, ComponentLoadSpec):
load_spec = spec
else:
raise ValueError(f"Invalid spec type: {type(spec)}")
component = self.load_component(spec, **kwargs)
self.add(name, component, collection=collection, load_spec=spec)
for k, v in self.components_specs.items():
if v == load_spec and not force_add:
logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.")
return
component = load_spec.load(**kwargs)
self.add(name, component, collection=collection, load_spec=load_spec)
def remove(self, name):

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import re
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ..utils.import_utils import is_torch_available
@@ -27,26 +27,53 @@ if is_torch_available():
class ComponentSpec:
"""Specification for a pipeline component."""
name: str
# YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance
type_hint: Type
type_hint: Type # YiYi Notes: change to component_type?
description: Optional[str] = None
config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor
repo: Optional[Union[str, List[str]]] = None
subfolder: Optional[str] = None
revision: Optional[str] = None
variant: Optional[str] = None
def create(self, **kwargs) -> Any:
"""
Create the component based on the config and additional kwargs.
Args:
**kwargs: Additional arguments to pass to the component's __init__ method
Returns:
The created component
"""
if self.config is not None:
init_kwargs = self.config
else:
init_kwargs = {}
return self.type_hint(**init_kwargs, **kwargs)
def load(self, **kwargs) -> Any:
return self.to_load_spec().load(**kwargs)
def to_load_spec(self) -> "ComponentLoadSpec":
"""Convert to a ComponentLoadSpec for storage in ComponentsManager."""
return ComponentLoadSpec.from_component_spec(self)
@dataclass
class ComponentLoadSpec:
type_hint: type
repo: Optional[str] = None
subfolder: Optional[str] = None
revision: Optional[str] = None
variant: Optional[str] = None
def load(self, **kwargs) -> Any:
"""Load the component from the repository."""
repo = kwargs.pop("repo", self.repo)
subfolder = kwargs.pop("subfolder", self.subfolder)
return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs)
@classmethod
def from_component_spec(cls, component_spec: ComponentSpec):
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant)
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder)
@dataclass
class ConfigSpec:
@@ -54,7 +81,7 @@ class ConfigSpec:
name: str
value: Any
description: Optional[str] = None
repo: Optional[Union[str, List[str]]] = None
repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed
@dataclass
class InputParam: