From de8ce5274393b4213f66ce92574bd6c5d465871f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 01:09:33 +0200 Subject: [PATCH] up --- src/diffusers/pipelines/components_manager.py | 294 ++++++++++++++---- src/diffusers/pipelines/modular_pipeline.py | 12 +- 2 files changed, 247 insertions(+), 59 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index eaa2abaa7f..d2c8e9e1f1 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -232,6 +232,7 @@ class AutoOffloadStrategy: from .modular_pipeline_utils import ComponentSpec +import uuid class ComponentsManager: def __init__(self): self.components = OrderedDict() @@ -240,26 +241,65 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. + """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. + """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + def add(self, name, component, collection: Optional[str] = None): + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - name = f"{name}_{component._diffusers_load_id}" - - if name in self.components: - logger.warning(f"Overriding existing component '{name}' in ComponentsManager") + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) - self.components[name] = component - self.added_time[name] = time.time() + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() if collection: if collection not in self.collections: self.collections[collection] = set() - self.collections[collection].add(name) + self.collections[collection].add(component_id) if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id - def remove(self, name): + def remove(self, name: Union[str, List[str]]): if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") @@ -275,27 +315,83 @@ class ComponentsManager: if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: looking into improving the search pattern - def get(self, names: Union[str, List[str]]): + # YiYi TODO: looking into improving the search pattern and refactor the code + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None): """ - Get components by name with simple pattern matching. + Select components by name with simple pattern matching. Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" Returns: Single component if names is str and matches one component, dict of components if names matches multiple components or is a list """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + if names is None: + return components + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -305,33 +401,49 @@ class ComponentsManager: # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') - matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - # Exact match - elif names in self.components: + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } + if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # If there's exactly one match and it's not a NOT pattern, return the component directly + if len(matches) == 1 and not is_not_pattern: + return next(iter(matches.values())) # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -342,16 +454,27 @@ class ComponentsManager: elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") @@ -360,7 +483,7 @@ class ComponentsManager: elif isinstance(names, list): results = {} for name in names: - result = self.get(name) + result = self.get(name, collection) if isinstance(result, dict): results.update(result) else: @@ -409,6 +532,7 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False + # YiYi TODO: add quantization info def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -431,14 +555,23 @@ class ComponentsManager: info = { "model_id": name, "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -472,12 +605,56 @@ class ComponentsManager: return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) } # Create the header lines @@ -494,17 +671,23 @@ class ComponentsManager: if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" output += dash_line # Other components section @@ -513,12 +696,16 @@ class ComponentsManager: output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += dash_line # Add additional component info @@ -526,7 +713,8 @@ class ComponentsManager: for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 3ab4629343..1f1784b186 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1101,7 +1101,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): # current component spec component_spec = self._component_specs.get(name) if component_spec is None: - logger.warning(f"register_components: skipping unknown component '{name}'") + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue is_registered = hasattr(self, name) @@ -1143,17 +1143,17 @@ class ModularLoader(ConfigMixin, PushToHubMixin): current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: - logger.info(f"register_components: {name} is already registered with same object, skipping") + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue # it module is not an instance of the expected type, still register it but with a warning if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: logger.info( - f"register_components: setting '{name}' to None " + f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) # same type, new instance → debug @@ -1162,7 +1162,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): and isinstance(module, current_module.__class__) \ and current_module != module: logger.debug( - f"register_components: replacing existing '{name}' " + f"ModularLoader.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" ) @@ -1343,7 +1343,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") if len(kwargs) > 0: - raise logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") self.register_components(**passed_components)