mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
make component spec loadable: add load/create method
This commit is contained in:
@@ -241,19 +241,6 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
|
||||
def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs):
|
||||
module_class = spec.type_hint
|
||||
|
||||
|
||||
if spec.revision is not None:
|
||||
kwargs["revision"] = spec.revision
|
||||
if spec.variant is not None:
|
||||
kwargs["variant"] = spec.variant
|
||||
|
||||
component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs)
|
||||
return component
|
||||
|
||||
def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None):
|
||||
if name in self.components:
|
||||
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
|
||||
@@ -284,21 +271,23 @@ class ComponentsManager:
|
||||
**kwargs: Additional arguments to pass to the component loader
|
||||
"""
|
||||
|
||||
if isinstance(spec, ComponentSpec):
|
||||
if spec.config is not None:
|
||||
component = spec.type_hint(**spec.config)
|
||||
self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec))
|
||||
return
|
||||
|
||||
spec = ComponentLoadSpec.from_component_spec(spec)
|
||||
|
||||
for k, v in self.components_specs.items():
|
||||
if v == spec and not force_add:
|
||||
logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.")
|
||||
return
|
||||
if isinstance(spec, ComponentSpec) and spec.repo is None:
|
||||
component = spec.create(**kwargs)
|
||||
self.add(name, component, collection=collection)
|
||||
elif isinstance(spec, ComponentSpec):
|
||||
load_spec = spec.to_load_spec()
|
||||
elif isinstance(spec, ComponentLoadSpec):
|
||||
load_spec = spec
|
||||
else:
|
||||
raise ValueError(f"Invalid spec type: {type(spec)}")
|
||||
|
||||
component = self.load_component(spec, **kwargs)
|
||||
self.add(name, component, collection=collection, load_spec=spec)
|
||||
for k, v in self.components_specs.items():
|
||||
if v == load_spec and not force_add:
|
||||
logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.")
|
||||
return
|
||||
|
||||
component = load_spec.load(**kwargs)
|
||||
self.add(name, component, collection=collection, load_spec=load_spec)
|
||||
|
||||
def remove(self, name):
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..utils.import_utils import is_torch_available
|
||||
@@ -27,26 +27,53 @@ if is_torch_available():
|
||||
class ComponentSpec:
|
||||
"""Specification for a pipeline component."""
|
||||
name: str
|
||||
# YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance
|
||||
type_hint: Type
|
||||
type_hint: Type # YiYi Notes: change to component_type?
|
||||
description: Optional[str] = None
|
||||
config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor
|
||||
repo: Optional[Union[str, List[str]]] = None
|
||||
subfolder: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
variant: Optional[str] = None
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
"""
|
||||
Create the component based on the config and additional kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments to pass to the component's __init__ method
|
||||
|
||||
Returns:
|
||||
The created component
|
||||
"""
|
||||
if self.config is not None:
|
||||
init_kwargs = self.config
|
||||
else:
|
||||
init_kwargs = {}
|
||||
return self.type_hint(**init_kwargs, **kwargs)
|
||||
|
||||
def load(self, **kwargs) -> Any:
|
||||
return self.to_load_spec().load(**kwargs)
|
||||
|
||||
def to_load_spec(self) -> "ComponentLoadSpec":
|
||||
"""Convert to a ComponentLoadSpec for storage in ComponentsManager."""
|
||||
return ComponentLoadSpec.from_component_spec(self)
|
||||
|
||||
@dataclass
|
||||
class ComponentLoadSpec:
|
||||
type_hint: type
|
||||
repo: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
variant: Optional[str] = None
|
||||
|
||||
def load(self, **kwargs) -> Any:
|
||||
"""Load the component from the repository."""
|
||||
repo = kwargs.pop("repo", self.repo)
|
||||
subfolder = kwargs.pop("subfolder", self.subfolder)
|
||||
|
||||
return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_component_spec(cls, component_spec: ComponentSpec):
|
||||
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant)
|
||||
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigSpec:
|
||||
@@ -54,7 +81,7 @@ class ConfigSpec:
|
||||
name: str
|
||||
value: Any
|
||||
description: Optional[str] = None
|
||||
repo: Optional[Union[str, List[str]]] = None
|
||||
repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed
|
||||
|
||||
@dataclass
|
||||
class InputParam:
|
||||
|
||||
Reference in New Issue
Block a user