diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py deleted file mode 100644 index bdff133e22..0000000000 --- a/src/diffusers/pipelines/components_manager.py +++ /dev/null @@ -1,862 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy - -import torch -import time -from dataclasses import dataclass - -from ..utils import ( - is_accelerate_available, - logging, -) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec - - -if is_accelerate_available(): - from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module - from accelerate.state import PartialState - from accelerate.utils import send_to_device - from accelerate.utils.memory import clear_device_cache - from accelerate.utils.modeling import convert_file_size_to_int - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# YiYi Notes: copied from modeling_utils.py (decide later where to put this) -def get_memory_footprint(self, return_buffers=True): - r""" - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to - benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch - discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 - - Arguments: - return_buffers (`bool`, *optional*, defaults to `True`): - Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are - tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm - layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 - """ - mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) - if return_buffers: - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) - mem = mem + mem_bufs - return mem - - -class CustomOffloadHook(ModelHook): - """ - A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are - on the given device. Optionally offloads other models to the CPU before the forward pass is called. - - Args: - execution_device(`str`, `int` or `torch.device`, *optional*): - The device on which the model should be executed. Will default to the MPS device if it's available, then - GPU 0 if there is a GPU, and finally to the CPU. - """ - - def __init__( - self, - execution_device: Optional[Union[str, int, torch.device]] = None, - other_hooks: Optional[List["UserCustomOffloadHook"]] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, - ): - self.execution_device = execution_device if execution_device is not None else PartialState().default_device - self.other_hooks = other_hooks - self.offload_strategy = offload_strategy - self.model_id = None - - def set_strategy(self, offload_strategy: "AutoOffloadStrategy"): - self.offload_strategy = offload_strategy - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - """ - Add a hook to the list of hooks to consider for offloading. - """ - if self.other_hooks is None: - self.other_hooks = [] - self.other_hooks.append(hook) - - def init_hook(self, module): - return module.to("cpu") - - def pre_forward(self, module, *args, **kwargs): - if module.device != self.execution_device: - if self.other_hooks is not None: - hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device] - # offload all other hooks - start_time = time.perf_counter() - if self.offload_strategy is not None: - hooks_to_offload = self.offload_strategy( - hooks=hooks_to_offload, - model_id=self.model_id, - model=module, - execution_device=self.execution_device, - ) - end_time = time.perf_counter() - logger.info( - f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds" - ) - - for hook in hooks_to_offload: - logger.info( - f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu" - ) - hook.offload() - - if hooks_to_offload: - clear_device_cache() - module.to(self.execution_device) - return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) - - -class UserCustomOffloadHook: - """ - A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of - the hook or remove it entirely. - """ - - def __init__(self, model_id, model, hook): - self.model_id = model_id - self.model = model - self.hook = hook - - def offload(self): - self.hook.init_hook(self.model) - - def attach(self): - add_hook_to_module(self.model, self.hook) - self.hook.model_id = self.model_id - - def remove(self): - remove_hook_from_module(self.model) - self.hook.model_id = None - - def add_other_hook(self, hook: "UserCustomOffloadHook"): - self.hook.add_other_hook(hook) - - -def custom_offload_with_hook( - model_id: str, - model: torch.nn.Module, - execution_device: Union[str, int, torch.device] = None, - offload_strategy: Optional["AutoOffloadStrategy"] = None, -): - hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy) - user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook) - user_hook.attach() - return user_hook - - -class AutoOffloadStrategy: - """ - Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on - the available memory on the device. - """ - - def __init__(self, memory_reserve_margin="3GB"): - self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) - - def __call__(self, hooks, model_id, model, execution_device): - if len(hooks) == 0: - return [] - - current_module_size = get_memory_footprint(model) - - mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] - mem_on_device = mem_on_device - self.memory_reserve_margin - if current_module_size < mem_on_device: - return [] - - min_memory_offload = current_module_size - mem_on_device - logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory") - - # exlucde models that's not currently loaded on the device - module_sizes = dict( - sorted( - {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), - key=lambda x: x[1], - reverse=True, - ) - ) - - def search_best_candidate(module_sizes, min_memory_offload): - """ - search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a - minimum memory offload size. the combination of models should add up to the smallest modulesize that is - larger than `min_memory_offload` - """ - model_ids = list(module_sizes.keys()) - best_candidate = None - best_size = float("inf") - for r in range(1, len(model_ids) + 1): - for candidate_model_ids in combinations(model_ids, r): - candidate_size = sum( - module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids - ) - if candidate_size < min_memory_offload: - continue - else: - if best_candidate is None or candidate_size < best_size: - best_candidate = candidate_model_ids - best_size = candidate_size - - return best_candidate - - best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload) - - if best_offload_model_ids is None: - # if no combination is found, meaning that we cannot meet the memory requirement, offload all models - logger.warning("no combination of models to offload to cpu is found, offloading all models") - hooks_to_offload = hooks - else: - hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids] - - return hooks_to_offload - - - -from .modular_pipeline_utils import ComponentSpec -import uuid -class ComponentsManager: - def __init__(self): - self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added - self.collections = OrderedDict() # collection_name -> set of component_names - self.model_hooks = None - self._auto_offload_enabled = False - - - def _get_by_collection(self, collection: str): - """ - Select components by collection name. - """ - selected_components = {} - if collection in self.collections: - component_ids = self.collections[collection] - for component_id in component_ids: - selected_components[component_id] = self.components[component_id] - return selected_components - - - def _get_by_load_id(self, load_id: str): - """ - Select components by its load_id. - """ - selected_components = {} - for name, component in self.components.items(): - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: - selected_components[name] = component - return selected_components - - - def add(self, name, component, collection: Optional[str] = None): - - for comp_id, comp in self.components.items(): - if comp == component: - logger.warning(f"Component '{name}' already exists in ComponentsManager") - return comp_id - - component_id = f"{name}_{uuid.uuid4()}" - - if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) - if components_with_same_load_id: - existing = ", ".join(components_with_same_load_id.keys()) - logger.warning( - f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " - f"To remove a duplicate, call `components_manager.remove('')`." - ) - - - # add component to components manager - self.components[component_id] = component - self.added_time[component_id] = time.time() - if collection: - if collection not in self.collections: - self.collections[collection] = set() - self.collections[collection].add(component_id) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") - return component_id - - - def remove(self, name: Union[str, List[str]]): - - if name not in self.components: - logger.warning(f"Component '{name}' not found in ComponentsManager") - return - - self.components.pop(name) - self.added_time.pop(name) - - for collection in self.collections: - if name in self.collections[collection]: - self.collections[collection].remove(name) - - if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - - def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False): - """ - Select components by name with simple pattern matching. - - Args: - names: Component name(s) or pattern(s) - Patterns: - - "unet" : match any component with base name "unet" (e.g., unet_123abc) - - "!unet" : everything except components with base name "unet" - - "unet*" : anything with base name starting with "unet" - - "!unet*" : anything with base name NOT starting with "unet" - - "*unet*" : anything with base name containing "unet" - - "!*unet*" : anything with base name NOT containing "unet" - - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" - - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" - - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" - collection: Optional collection to filter by - load_id: Optional load_id to filter by - as_name_component_tuples: If True, returns a list of (name, component) tuples using base names - instead of a dictionary with component IDs as keys - - Returns: - Dictionary mapping component IDs to components, - or list of (base_name, component) tuples if as_name_component_tuples=True - """ - - if collection: - if collection not in self.collections: - logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return [] if as_name_component_tuples else {} - components = self._get_by_collection(collection) - else: - components = self.components - - if load_id: - components = self._get_by_load_id(load_id) - - # Helper to extract base name from component_id - def get_base_name(component_id): - parts = component_id.split('_') - # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return component_id - - if names is None: - if as_name_component_tuples: - return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] - else: - return components - - # Create mapping from component_id to base_name for all components - base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - - def matches_pattern(component_id, pattern, exact_match=False): - """ - Helper function to check if a component matches a pattern based on its base name. - - Args: - component_id: The component ID to check - pattern: The pattern to match against - exact_match: If True, only exact matches to base_name are considered - """ - base_name = base_names[component_id] - - # Exact match with base name - if exact_match: - return pattern == base_name - - # Prefix match (ends with *) - elif pattern.endswith('*'): - prefix = pattern[:-1] - return base_name.startswith(prefix) - - # Contains match (starts with *) - elif pattern.startswith('*'): - search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] - return search in base_name - - # Exact match (no wildcards) - else: - return pattern == base_name - - if isinstance(names, str): - # Check if this is a "not" pattern - is_not_pattern = names.startswith('!') - if is_not_pattern: - names = names[1:] # Remove the ! prefix - - # Handle OR patterns (containing |) - if '|' in names: - terms = names.split('|') - matches = {} - - for comp_id, comp in components.items(): - # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - - # Check if any of the terms match this component - should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - - # Flip the decision if this is a NOT pattern - if is_not_pattern: - should_include = not should_include - - if should_include: - matches[comp_id] = comp - - log_msg = "NOT " if is_not_pattern else "" - match_type = "exactly matching" if exact_match else "matching any of patterns" - logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - - # Try exact match with a base name - elif any(names == base_name for base_name in base_names.values()): - # Find all components with this base name - matches = { - comp_id: comp for comp_id, comp in components.items() - if (base_names[comp_id] == names) != is_not_pattern - } - - if is_not_pattern: - logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - - # Prefix match (ends with *) - elif names.endswith('*'): - prefix = names[:-1] - matches = { - comp_id: comp for comp_id, comp in components.items() - if base_names[comp_id].startswith(prefix) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") - else: - logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - - # Contains match (starts with *) - elif names.startswith('*'): - search = names[1:-1] if names.endswith('*') else names[1:] - matches = { - comp_id: comp for comp_id, comp in components.items() - if (search in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - - # Substring match (no wildcards, but not an exact component name) - elif any(names in base_name for base_name in base_names.values()): - matches = { - comp_id: comp for comp_id, comp in components.items() - if (names in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - - else: - raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") - - if not matches: - raise ValueError(f"No components found matching pattern '{names}'") - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] - else: - return matches - - elif isinstance(names, list): - results = {} - for name in names: - result = self.get(name, collection, load_id, as_name_component_tuples=False) - results.update(result) - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in results.items()] - else: - return results - - else: - raise ValueError(f"Invalid type for names: {type(names)}") - - def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"): - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): - remove_hook_from_module(component, recurse=True) - - self.disable_auto_cpu_offload() - offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) - device = torch.device(device) - if device.index is None: - device = torch.device(f"{device.type}:{0}") - all_hooks = [] - for name, component in self.components.items(): - if isinstance(component, torch.nn.Module): - hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy) - all_hooks.append(hook) - - for hook in all_hooks: - other_hooks = [h for h in all_hooks if h is not hook] - for other_hook in other_hooks: - if other_hook.hook.execution_device == hook.hook.execution_device: - hook.add_other_hook(other_hook) - - self.model_hooks = all_hooks - self._auto_offload_enabled = True - self._auto_offload_device = device - - def disable_auto_cpu_offload(self): - if self.model_hooks is None: - self._auto_offload_enabled = False - return - - for hook in self.model_hooks: - hook.offload() - hook.remove() - if self.model_hooks: - clear_device_cache() - self.model_hooks = None - self._auto_offload_enabled = False - - # YiYi TODO: add quantization info - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: - """Get comprehensive information about a component. - - Args: - name: Name of the component to get info for - fields: Optional field(s) to return. Can be a string for single field or list of fields. - If None, returns all fields. - - Returns: - Dictionary containing requested component metadata. - If fields is specified, returns only those fields. - If a single field is requested as string, returns just that field's value. - """ - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") - - component = self.components[name] - - # Build complete info dict first - info = { - "model_id": name, - "added_time": self.added_time[name], - "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), - } - - # Additional info for torch.nn.Module components - if isinstance(component, torch.nn.Module): - # Check for hook information - has_hook = hasattr(component, "_hf_hook") - execution_device = None - if has_hook and hasattr(component._hf_hook, "execution_device"): - execution_device = component._hf_hook.execution_device - - info.update({ - "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), - "adapters": None, # Default to None - "has_hook": has_hook, - "execution_device": execution_device, - }) - - # Get adapters if applicable - if hasattr(component, "peft_config"): - info["adapters"] = list(component.peft_config.keys()) - - # Check for IP-Adapter scales - if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"): - processors = copy.deepcopy(component.attn_processors) - # First check if any processor is an IP-Adapter - processor_types = [v.__class__.__name__ for v in processors.values()] - if any("IPAdapter" in ptype for ptype in processor_types): - # Then get scales only from IP-Adapter processors - scales = { - k: v.scale - for k, v in processors.items() - if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ - } - if scales: - info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) - - # If fields specified, filter info - if fields is not None: - if isinstance(fields, str): - # Single field requested, return just that value - return {fields: info.get(fields)} - else: - # List of fields requested, return dict with just those fields - return {k: v for k, v in info.items() if k in fields} - - return info - - def __repr__(self): - # Helper to get simple name without UUID - def get_simple_name(name): - # Extract the base name by splitting on underscore and taking first part - # This assumes names are in format "name_uuid" - parts = name.split('_') - # If we have at least 2 parts and the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: - return '_'.join(parts[:-1]) - return name - - # Extract load_id if available - def get_load_id(component): - if hasattr(component, "_diffusers_load_id"): - return component._diffusers_load_id - return "N/A" - - # Format device info compactly - def format_device(component, info): - if not info["has_hook"]: - return str(getattr(component, 'device', 'N/A')) - else: - device = str(getattr(component, 'device', 'N/A')) - exec_device = str(info['execution_device'] or 'N/A') - return f"{device}({exec_device})" - - # Get all simple names to calculate width - simple_names = [get_simple_name(id) for id in self.components.keys()] - - # Get max length of load_ids for models - load_ids = [ - get_load_id(component) - for component in self.components.values() - if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") - ] - max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - - # Collection names - collection_names = [ - next((coll for coll, comps in self.collections.items() if name in comps), "N/A") - for name in self.components.keys() - ] - - col_widths = { - "name": max(15, max(len(name) for name in simple_names)), - "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 15, # Reduced since using more compact format - "dtype": 15, - "size": 10, - "load_id": max_load_id_len, - "collection": max(10, max(len(str(c)) for c in collection_names)) - } - - # Create the header lines - sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n" - - output = "Components:\n" + sep_line - - # Separate components into models and others - models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} - others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)} - - # Models section - if models: - output += "Models:\n" + dash_line - # Column headers - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " - output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" - output += dash_line - - # Model entries - for name, component in models.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - device_str = format_device(component, info) - dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - load_id = get_load_id(component) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " - output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" - output += dash_line - - # Other components section - if others: - if models: # Add extra newline if we had models section - output += "\n" - output += "Other Components:\n" + dash_line - # Column headers for other components - output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" - output += dash_line - - # Other component entries - for name, component in others.items(): - info = self.get_model_info(name) - simple_name = get_simple_name(name) - collection = info["collection"] or "N/A" - - output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" - output += dash_line - - # Add additional component info - output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" - for name in self.components: - info = self.get_model_info(name) - if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - simple_name = get_simple_name(name) - output += f"\n{simple_name}:\n" - if info.get("adapters") is not None: - output += f" Adapters: {info['adapters']}\n" - if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" - output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" - - return output - - def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): - """ - Load components from a pretrained model and add them to the manager. - - Args: - pretrained_model_name_or_path (str): The path or identifier of the pretrained model - prefix (str, optional): Prefix to add to all component names loaded from this model. - If provided, components will be named as "{prefix}_{component_name}" - **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() - """ - subfolder = kwargs.pop("subfolder", None) - # YiYi TODO: extend AutoModel to support non-diffusers models - if subfolder: - from ..models import AutoModel - component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) - component_name = f"{prefix}_{subfolder}" if prefix else subfolder - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - else: - from ..pipelines.pipeline_utils import DiffusionPipeline - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - - if component_name not in self.components: - self.add(component_name, component) - else: - logger.warning( - f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" - f"1. remove the existing component with remove('{component_name}')\n" - f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" - ) - - def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: - """ - Get a single component by name. Raises an error if multiple components match or none are found. - - Args: - name: Component name or pattern - collection: Optional collection to filter by - load_id: Optional load_id to filter by - - Returns: - A single component - - Raises: - ValueError: If no components match or multiple components match - """ - results = self.get(name, collection, load_id) - - if not results: - raise ValueError(f"No components found matching '{name}'") - - if len(results) > 1: - raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - - return next(iter(results.values())) - -def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: - """Summarizes a dictionary by finding common prefixes that share the same value. - - For a dictionary with dot-separated keys like: - { - 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], - 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], - 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], - } - - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: - { - 'down_blocks': [0.6], - 'up_blocks': [0.3] - } - """ - # First group by values - convert lists to tuples to make them hashable - value_to_keys = {} - for key, value in d.items(): - value_tuple = tuple(value) if isinstance(value, list) else value - if value_tuple not in value_to_keys: - value_to_keys[value_tuple] = [] - value_to_keys[value_tuple].append(key) - - def find_common_prefix(keys: List[str]) -> str: - """Find the shortest common prefix among a list of dot-separated keys.""" - if not keys: - return "" - if len(keys) == 1: - return keys[0] - - # Split all keys into parts - key_parts = [k.split('.') for k in keys] - - # Find how many initial parts are common - common_length = 0 - for parts in zip(*key_parts): - if len(set(parts)) == 1: # All parts at this position are the same - common_length += 1 - else: - break - - if common_length == 0: - return "" - - # Return the common prefix - return '.'.join(key_parts[0][:common_length]) - - # Create summary by finding common prefixes for each value group - summary = {} - for value_tuple, keys in value_to_keys.items(): - prefix = find_common_prefix(keys) - if prefix: # Only add if we found a common prefix - # Convert tuple back to list if it was originally a list - value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple - summary[prefix] = value - else: - summary[""] = value # Use empty string if no common prefix - - return summary