1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

update components manageer: add -collection arg, allow subfoldeer arg in from_pretrained

This commit is contained in:
yiyixuxu
2025-04-23 19:42:21 +02:00
parent 170a3c5736
commit e38f09ba97
2 changed files with 40 additions and 15 deletions

View File

@@ -26,6 +26,7 @@ from ..utils import (
logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
if is_accelerate_available():
@@ -232,26 +233,36 @@ class AutoOffloadStrategy:
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
self.added_time = OrderedDict() # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def add(self, name, component):
def add(self, name, component, collection: Optional[str] = None):
if name in self.components:
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
self.components[name] = component
self.added_time[name] = time.time()
if collection:
if collection not in self.collections:
self.collections[collection] = set()
self.collections[collection].add(name)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
def remove(self, name):
if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager")
return
self.components.pop(name)
self.added_time.pop(name)
for collection in self.collections:
if name in self.collections[collection]:
self.collections[collection].remove(name)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
@@ -516,7 +527,7 @@ class ComponentsManager:
return output
def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
"""
Load components from a pretrained model and add them to the manager.
@@ -526,17 +537,12 @@ class ComponentsManager:
If provided, components will be named as "{prefix}_{component_name}"
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
subfolder = kwargs.pop("subfolder", None)
# YiYi TODO: extend auto model to support non-diffusers models
if subfolder:
from ..models import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
if component_name not in self.components:
self.add(component_name, component)
else:
@@ -545,6 +551,25 @@ class ComponentsManager:
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
else:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.