diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 8c14321ccf..4bd4c23c28 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -26,6 +26,7 @@ from ..utils import ( logging, ) from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec if is_accelerate_available(): @@ -232,26 +233,36 @@ class AutoOffloadStrategy: class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component): + def add(self, name, component, collection: Optional[str] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") self.components[name] = component self.added_time[name] = time.time() - + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(name) + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) def remove(self, name): + if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") return self.components.pop(name) self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) @@ -516,7 +527,7 @@ class ComponentsManager: return output - def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. @@ -526,17 +537,12 @@ class ComponentsManager: If provided, components will be named as "{prefix}_{component_name}" **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend auto model to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: self.add(component_name, component) else: @@ -545,6 +551,25 @@ class ComponentsManager: f"1. remove the existing component with remove('{component_name}')\n" f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/pipelines/modular_pipeline_util.py b/src/diffusers/pipelines/modular_pipeline_utils.py similarity index 100% rename from src/diffusers/pipelines/modular_pipeline_util.py rename to src/diffusers/pipelines/modular_pipeline_utils.py