mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update components manager, allow loading with spec
This commit is contained in:
@@ -230,26 +230,75 @@ class AutoOffloadStrategy:
|
||||
return hooks_to_offload
|
||||
|
||||
|
||||
|
||||
from .modular_pipeline_utils import ComponentSpec, ComponentLoadSpec
|
||||
class ComponentsManager:
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
self.added_time = OrderedDict() # Store when components were added
|
||||
self.components_specs = OrderedDict()
|
||||
self.collections = OrderedDict() # collection_name -> set of component_names
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
def add(self, name, component, collection: Optional[str] = None):
|
||||
|
||||
def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs):
|
||||
module_class = spec.type_hint
|
||||
|
||||
|
||||
if spec.revision is not None:
|
||||
kwargs["revision"] = spec.revision
|
||||
if spec.variant is not None:
|
||||
kwargs["variant"] = spec.variant
|
||||
|
||||
component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs)
|
||||
return component
|
||||
|
||||
def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None):
|
||||
if name in self.components:
|
||||
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
|
||||
|
||||
self.components[name] = component
|
||||
self.added_time[name] = time.time()
|
||||
if collection:
|
||||
if collection not in self.collections:
|
||||
self.collections[collection] = set()
|
||||
self.collections[collection].add(name)
|
||||
|
||||
if load_spec is not None:
|
||||
self.components_specs[name] = load_spec
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
# YiYi TODO: combine this with add method?
|
||||
def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], collection: Optional[str] = None, force_add: bool = False, **kwargs):
|
||||
"""
|
||||
Add a component to the manager.
|
||||
|
||||
Args:
|
||||
name: Name of the component in the ComponentsManager
|
||||
component: The ComponentSpec to load
|
||||
collection: Optional collection to add the component to
|
||||
force_add: If True, always add the component even if the ComponentSpec already exists
|
||||
**kwargs: Additional arguments to pass to the component loader
|
||||
"""
|
||||
|
||||
if isinstance(spec, ComponentSpec):
|
||||
if spec.config is not None:
|
||||
component = spec.type_hint(**spec.config)
|
||||
self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec))
|
||||
return
|
||||
|
||||
spec = ComponentLoadSpec.from_component_spec(spec)
|
||||
|
||||
for k, v in self.components_specs.items():
|
||||
if v == spec and not force_add:
|
||||
logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.")
|
||||
return
|
||||
|
||||
component = self.load_component(spec, **kwargs)
|
||||
self.add(name, component, collection=collection, load_spec=spec)
|
||||
|
||||
def remove(self, name):
|
||||
|
||||
@@ -538,7 +587,7 @@ class ComponentsManager:
|
||||
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
|
||||
"""
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
# YiYi TODO: extend auto model to support non-diffusers models
|
||||
# YiYi TODO: extend AutoModel to support non-diffusers models
|
||||
if subfolder:
|
||||
from ..models import AutoModel
|
||||
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
|
||||
|
||||
@@ -1158,8 +1158,7 @@ class ModularPipelineMixin:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
|
||||
# YiYi NOTE: not sure if this needs to be a ConfigMixin
|
||||
class ModularPipelineLoader(ConfigMixin):
|
||||
class ModularLoader(ConfigMixin):
|
||||
"""
|
||||
Base class for all Modular pipelines loaders.
|
||||
|
||||
|
||||
@@ -36,6 +36,18 @@ class ComponentSpec:
|
||||
revision: Optional[str] = None
|
||||
variant: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ComponentLoadSpec:
|
||||
type_hint: type
|
||||
repo: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
variant: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_component_spec(cls, component_spec: ComponentSpec):
|
||||
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant)
|
||||
|
||||
@dataclass
|
||||
class ConfigSpec:
|
||||
"""Specification for a pipeline configuration parameter."""
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ..modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
ModularPipelineLoader,
|
||||
ModularLoader,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
InputParam,
|
||||
@@ -3578,7 +3578,7 @@ SDXL_SUPPORTED_BLOCKS = {
|
||||
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
|
||||
## (5) how to use together with Components_manager?
|
||||
class StableDiffusionXLModularLoader(
|
||||
ModularPipelineLoader,
|
||||
ModularLoader,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
|
||||
Reference in New Issue
Block a user