1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

update components manager, allow loading with spec

This commit is contained in:
yiyixuxu
2025-04-24 06:44:26 +02:00
parent 3b30e794d0
commit d456a97420
4 changed files with 67 additions and 7 deletions

View File

@@ -230,26 +230,75 @@ class AutoOffloadStrategy:
return hooks_to_offload
from .modular_pipeline_utils import ComponentSpec, ComponentLoadSpec
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
self.components_specs = OrderedDict()
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def add(self, name, component, collection: Optional[str] = None):
def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs):
module_class = spec.type_hint
if spec.revision is not None:
kwargs["revision"] = spec.revision
if spec.variant is not None:
kwargs["variant"] = spec.variant
component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs)
return component
def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None):
if name in self.components:
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
self.components[name] = component
self.added_time[name] = time.time()
if collection:
if collection not in self.collections:
self.collections[collection] = set()
self.collections[collection].add(name)
if load_spec is not None:
self.components_specs[name] = load_spec
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
self.enable_auto_cpu_offload(self._auto_offload_device)
# YiYi TODO: combine this with add method?
def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], collection: Optional[str] = None, force_add: bool = False, **kwargs):
"""
Add a component to the manager.
Args:
name: Name of the component in the ComponentsManager
component: The ComponentSpec to load
collection: Optional collection to add the component to
force_add: If True, always add the component even if the ComponentSpec already exists
**kwargs: Additional arguments to pass to the component loader
"""
if isinstance(spec, ComponentSpec):
if spec.config is not None:
component = spec.type_hint(**spec.config)
self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec))
return
spec = ComponentLoadSpec.from_component_spec(spec)
for k, v in self.components_specs.items():
if v == spec and not force_add:
logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.")
return
component = self.load_component(spec, **kwargs)
self.add(name, component, collection=collection, load_spec=spec)
def remove(self, name):
@@ -538,7 +587,7 @@ class ComponentsManager:
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
subfolder = kwargs.pop("subfolder", None)
# YiYi TODO: extend auto model to support non-diffusers models
# YiYi TODO: extend AutoModel to support non-diffusers models
if subfolder:
from ..models import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)

View File

@@ -1158,8 +1158,7 @@ class ModularPipelineMixin:
raise ValueError(f"Output '{output}' is not a valid output type")
# YiYi NOTE: not sure if this needs to be a ConfigMixin
class ModularPipelineLoader(ConfigMixin):
class ModularLoader(ConfigMixin):
"""
Base class for all Modular pipelines loaders.

View File

@@ -36,6 +36,18 @@ class ComponentSpec:
revision: Optional[str] = None
variant: Optional[str] = None
@dataclass
class ComponentLoadSpec:
type_hint: type
repo: Optional[str] = None
subfolder: Optional[str] = None
revision: Optional[str] = None
variant: Optional[str] = None
@classmethod
def from_component_spec(cls, component_spec: ComponentSpec):
return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant)
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""

View File

@@ -34,7 +34,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..modular_pipeline import (
AutoPipelineBlocks,
ModularPipelineLoader,
ModularLoader,
PipelineBlock,
PipelineState,
InputParam,
@@ -3578,7 +3578,7 @@ SDXL_SUPPORTED_BLOCKS = {
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularLoader(
ModularPipelineLoader,
ModularLoader,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,