From 655512e2cf5ff7eb9c0daa7bf1ac7e970efa62f2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 28 Jun 2025 08:35:50 +0200 Subject: [PATCH] components manager: change get -> search_models; add get_ids, get_components_by_ids, get_components_by_names --- .../modular_pipelines/components_manager.py | 343 ++++++++++-------- 1 file changed, 199 insertions(+), 144 deletions(-) diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index df88f9570f..2e6c288ad9 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -232,6 +232,8 @@ class AutoOffloadStrategy: class ComponentsManager: + _available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"] + def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added @@ -239,9 +241,10 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False - def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): + def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None): """ - Lookup component_ids by name, collection, or load_id. + Lookup component_ids by name, collection, or load_id. Does not support pattern matching. + Returns a set of component_ids """ if components is None: components = self.components @@ -351,15 +354,16 @@ class ComponentsManager: if torch.cuda.is_available(): torch.cuda.empty_cache() - def get( + # YiYi TODO: rename to search_components for now, may remove this method + def search_components( self, - names: Union[str, List[str]] = None, + names: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, - as_name_component_tuples: bool = False, + return_dict_with_names: bool = True, ): """ - Select components by name with simple pattern matching. + Search components by name with simple pattern matching. Optionally filter by collection or load_id. Args: names: Component name(s) or pattern(s) @@ -375,34 +379,48 @@ class ComponentsManager: - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" collection: Optional collection to filter by load_id: Optional load_id to filter by - as_name_component_tuples: If True, returns a list of (name, component) tuples using base names - instead of a dictionary with component IDs as keys + return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found + If False, returns a dictionary with component IDs as keys Returns: - Dictionary mapping component IDs to components or list of (base_name, component) tuples if - as_name_component_tuples=True + Dictionary mapping component names to components if return_dict_with_names=True, + or a dictionary mapping component IDs to components if return_dict_with_names=False """ + # select components based on collection and load_id filters selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} - - # Helper to extract base name from component_id - def get_base_name(component_id): - parts = component_id.split("_") - # If the last part looks like a UUID, remove it - if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]: - return "_".join(parts[:-1]) - return component_id - - if names is None: - if as_name_component_tuples: - return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + + def get_return_dict(components, return_dict_with_names): + """ + Create a dictionary mapping component names to components if return_dict_with_names=True, + or a dictionary mapping component IDs to components if return_dict_with_names=False, + throw an error if duplicate component names are found when return_dict_with_names=True + """ + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + dict_to_return[comp_name] = comp + return dict_to_return else: return components - # Create mapping from component_id to base_name for all components - base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + # if no names are provided, return the filtered components as it is + if names is None: + return get_return_dict(components, return_dict_with_names) + + # if names is not a string, raise an error + elif not isinstance(names, str): + raise ValueError(f"Invalid type for `names: {type(names)}, only support string") + + # Create mapping from component_id to base_name for components to be used for pattern matching + base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()} + + # Helper function to check if a component matches a pattern based on its base name def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. @@ -432,113 +450,95 @@ class ComponentsManager: else: return pattern == base_name - if isinstance(names, str): - # Check if this is a "not" pattern - is_not_pattern = names.startswith("!") + # Check if this is a "not" pattern + is_not_pattern = names.startswith("!") + if is_not_pattern: + names = names[1:] # Remove the ! prefix + + # Handle OR patterns (containing |) + if "|" in names: + terms = names.split("|") + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern + } + if is_not_pattern: - names = names[1:] # Remove the ! prefix - - # Handle OR patterns (containing |) - if "|" in names: - terms = names.split("|") - matches = {} - - for comp_id, comp in components.items(): - # For OR patterns with exact names (no wildcards), we do exact matching on base names - exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms) - - # Check if any of the terms match this component - should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - - # Flip the decision if this is a NOT pattern - if is_not_pattern: - should_include = not should_include - - if should_include: - matches[comp_id] = comp - - log_msg = "NOT " if is_not_pattern else "" - match_type = "exactly matching" if exact_match else "matching any of patterns" - logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - - # Try exact match with a base name - elif any(names == base_name for base_name in base_names.values()): - # Find all components with this base name - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (base_names[comp_id] == names) != is_not_pattern - } - - if is_not_pattern: - logger.info( - f"Getting all components except those with base name '{names}': {list(matches.keys())}" - ) - else: - logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - - # Prefix match (ends with *) - elif names.endswith("*"): - prefix = names[:-1] - matches = { - comp_id: comp - for comp_id, comp in components.items() - if base_names[comp_id].startswith(prefix) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") - else: - logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - - # Contains match (starts with *) - elif names.startswith("*"): - search = names[1:-1] if names.endswith("*") else names[1:] - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (search in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - - # Substring match (no wildcards, but not an exact component name) - elif any(names in base_name for base_name in base_names.values()): - matches = { - comp_id: comp - for comp_id, comp in components.items() - if (names in base_names[comp_id]) != is_not_pattern - } - if is_not_pattern: - logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") - else: - logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - + logger.info( + f"Getting all components except those with base name '{names}': {list(matches.keys())}" + ) else: - raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - if not matches: - raise ValueError(f"No components found matching pattern '{names}'") - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + # Prefix match (ends with *) + elif names.endswith("*"): + prefix = names[:-1] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") else: - return matches + logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - elif isinstance(names, list): - results = {} - for name in names: - result = self.get(name, collection, load_id, as_name_component_tuples=False) - results.update(result) - - if as_name_component_tuples: - return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + # Contains match (starts with *) + elif names.startswith("*"): + search = names[1:-1] if names.endswith("*") else names[1:] + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: - return results + logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp + for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") else: - raise ValueError(f"Invalid type for names: {type(names)}") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") + + if not matches: + raise ValueError(f"No components found matching pattern '{names}'") + + return get_return_dict(matches, return_dict_with_names) def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"): if not is_accelerate_available(): @@ -582,16 +582,18 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False - # YiYi TODO: add quantization info + # YiYi TODO: (1) add quantization info def get_model_info( - self, component_id: str, fields: Optional[Union[str, List[str]]] = None + self, + component_id: str, + fields: Optional[Union[str, List[str]]] = None, ) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. - If None, returns all fields. + If None, uses the available_info_fields setting. Returns: Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a @@ -601,6 +603,14 @@ class ComponentsManager: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] + + # Validate fields if specified + if fields is not None: + if isinstance(fields, str): + fields = [fields] + for field in fields: + if field not in self._available_info_fields: + raise ValueError(f"Field '{field}' not found in available_info_fields") # Build complete info dict first info = { @@ -649,15 +659,11 @@ class ComponentsManager: # If fields specified, filter info if fields is not None: - if isinstance(fields, str): - # Single field requested, return just that value - return {fields: info.get(fields)} - else: - # List of fields requested, return dict with just those fields - return {k: v for k, v in info.items() if k in fields} - - return info - + return {k: v for k, v in info.items() if k in fields} + else: + return info + + # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table def __repr__(self): # Handle empty components case if not self.components: @@ -814,9 +820,14 @@ class ComponentsManager: load_id: Optional[str] = None, ) -> Any: """ - Get a single component by name. Raises an error if multiple components match or none are found. + Get a single component by either: + (1) searching name (pattern matching), collection, or load_id. + (2) passing in a component_id + Raises an error if multiple components match or none are found. + support pattern matching for name Args: + component_id: Optional component ID to get name: Component name or pattern collection: Optional collection to filter by load_id: Optional load_id to filter by @@ -828,15 +839,16 @@ class ComponentsManager: ValueError: If no components match or multiple components match """ - # if component_id is provided, return the component if component_id is not None and (name is not None or collection is not None or load_id is not None): - raise ValueError(" if component_id is provided, name, collection, and load_id must be None") - elif component_id is not None: + raise ValueError("If searching by component_id, do not pass name, collection, or load_id") + + # search by component_id + if component_id is not None: if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") return self.components[component_id] - - results = self.get(name, collection, load_id) + # search with name/collection/load_id + results = self.search_components(name, collection, load_id) if not results: raise ValueError(f"No components found matching '{name}'") @@ -845,20 +857,63 @@ class ComponentsManager: raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") return next(iter(results.values())) + + def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): + """ + Get component IDs by a list of names, optionally filtered by collection. + """ + ids = set() + for name in names: + ids.update(self._lookup_ids(name=name, collection=collection)) + return list(ids) + + def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): + """ + Get components by a list of IDs. + """ + components = {id: self.components[id] for id in ids} + + if return_dict_with_names: + dict_to_return = {} + for comp_id, comp in components.items(): + comp_name = self._id_to_name(comp_id) + if comp_name in dict_to_return: + raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + dict_to_return[comp_name] = comp + return dict_to_return + else: + return components + + def get_components_by_names(self, names: List[str], collection: Optional[str] = None): + """ + Get components by a list of names, optionally filtered by collection. + """ + ids = self.get_ids(names, collection) + return self.get_components_by_ids(ids) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. - For a dictionary with dot-separated keys like: { + This function is particularly useful for IP-Adapter attention processor patterns, where multiple + attention layers may share the same scale value. It groups dot-separated keys by their values + and finds the shortest common prefix for each group. + + For example, given a dictionary with IP-Adapter attention processor patterns like: + { 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], + 'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor': [0.3], } - Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { - 'down_blocks': [0.6], 'up_blocks': [0.3] + Returns a dictionary where keys are the shortest common prefixes and values are their shared values: + { + 'down_blocks.1.attentions.1.transformer_blocks': [0.6], + 'up_blocks.1.attentions.0.transformer_blocks': [0.3] } + + This helps identify which attention layers share the same IP-Adapter scale values. """ # First group by values - convert lists to tuples to make them hashable value_to_keys = {}