1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

components manager: change get -> search_models; add get_ids, get_components_by_ids, get_components_by_names

This commit is contained in:
yiyixuxu
2025-06-28 08:35:50 +02:00
parent f63d62e091
commit 655512e2cf

View File

@@ -232,6 +232,8 @@ class AutoOffloadStrategy:
class ComponentsManager:
_available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"]
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
@@ -239,9 +241,10 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None):
"""
Lookup component_ids by name, collection, or load_id.
Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
Returns a set of component_ids
"""
if components is None:
components = self.components
@@ -351,15 +354,16 @@ class ComponentsManager:
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get(
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
self,
names: Union[str, List[str]] = None,
names: Optional[str] = None,
collection: Optional[str] = None,
load_id: Optional[str] = None,
as_name_component_tuples: bool = False,
return_dict_with_names: bool = True,
):
"""
Select components by name with simple pattern matching.
Search components by name with simple pattern matching. Optionally filter by collection or load_id.
Args:
names: Component name(s) or pattern(s)
@@ -375,34 +379,48 @@ class ComponentsManager:
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
collection: Optional collection to filter by
load_id: Optional load_id to filter by
as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
instead of a dictionary with component IDs as keys
return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
If False, returns a dictionary with component IDs as keys
Returns:
Dictionary mapping component IDs to components or list of (base_name, component) tuples if
as_name_component_tuples=True
Dictionary mapping component names to components if return_dict_with_names=True,
or a dictionary mapping component IDs to components if return_dict_with_names=False
"""
# select components based on collection and load_id filters
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
# Helper to extract base name from component_id
def get_base_name(component_id):
parts = component_id.split("_")
# If the last part looks like a UUID, remove it
if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]:
return "_".join(parts[:-1])
return component_id
if names is None:
if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
def get_return_dict(components, return_dict_with_names):
"""
Create a dictionary mapping component names to components if return_dict_with_names=True,
or a dictionary mapping component IDs to components if return_dict_with_names=False,
throw an error if duplicate component names are found when return_dict_with_names=True
"""
if return_dict_with_names:
dict_to_return = {}
for comp_id, comp in components.items():
comp_name = self._id_to_name(comp_id)
if comp_name in dict_to_return:
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
dict_to_return[comp_name] = comp
return dict_to_return
else:
return components
# Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
# if no names are provided, return the filtered components as it is
if names is None:
return get_return_dict(components, return_dict_with_names)
# if names is not a string, raise an error
elif not isinstance(names, str):
raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
# Create mapping from component_id to base_name for components to be used for pattern matching
base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}
# Helper function to check if a component matches a pattern based on its base name
def matches_pattern(component_id, pattern, exact_match=False):
"""
Helper function to check if a component matches a pattern based on its base name.
@@ -432,113 +450,95 @@ class ComponentsManager:
else:
return pattern == base_name
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith("!")
# Check if this is a "not" pattern
is_not_pattern = names.startswith("!")
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if "|" in names:
terms = names.split("|")
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if "|" in names:
terms = names.split("|")
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
logger.info(
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
)
else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *)
elif names.endswith("*"):
prefix = names[:-1]
matches = {
comp_id: comp
for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *)
elif names.startswith("*"):
search = names[1:-1] if names.endswith("*") else names[1:]
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
logger.info(
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
)
else:
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
# Prefix match (ends with *)
elif names.endswith("*"):
prefix = names[:-1]
matches = {
comp_id: comp
for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
return matches
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result)
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
# Contains match (starts with *)
elif names.startswith("*"):
search = names[1:-1] if names.endswith("*") else names[1:]
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
return results
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp
for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else:
raise ValueError(f"Invalid type for names: {type(names)}")
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
return get_return_dict(matches, return_dict_with_names)
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
if not is_accelerate_available():
@@ -582,16 +582,18 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
# YiYi TODO: add quantization info
# YiYi TODO: (1) add quantization info
def get_model_info(
self, component_id: str, fields: Optional[Union[str, List[str]]] = None
self,
component_id: str,
fields: Optional[Union[str, List[str]]] = None,
) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
Args:
component_id: Name of the component to get info for
fields: Optional field(s) to return. Can be a string for single field or list of fields.
If None, returns all fields.
If None, uses the available_info_fields setting.
Returns:
Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a
@@ -601,6 +603,14 @@ class ComponentsManager:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[component_id]
# Validate fields if specified
if fields is not None:
if isinstance(fields, str):
fields = [fields]
for field in fields:
if field not in self._available_info_fields:
raise ValueError(f"Field '{field}' not found in available_info_fields")
# Build complete info dict first
info = {
@@ -649,15 +659,11 @@ class ComponentsManager:
# If fields specified, filter info
if fields is not None:
if isinstance(fields, str):
# Single field requested, return just that value
return {fields: info.get(fields)}
else:
# List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields}
return info
return {k: v for k, v in info.items() if k in fields}
else:
return info
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
def __repr__(self):
# Handle empty components case
if not self.components:
@@ -814,9 +820,14 @@ class ComponentsManager:
load_id: Optional[str] = None,
) -> Any:
"""
Get a single component by name. Raises an error if multiple components match or none are found.
Get a single component by either:
(1) searching name (pattern matching), collection, or load_id.
(2) passing in a component_id
Raises an error if multiple components match or none are found.
support pattern matching for name
Args:
component_id: Optional component ID to get
name: Component name or pattern
collection: Optional collection to filter by
load_id: Optional load_id to filter by
@@ -828,15 +839,16 @@ class ComponentsManager:
ValueError: If no components match or multiple components match
"""
# if component_id is provided, return the component
if component_id is not None and (name is not None or collection is not None or load_id is not None):
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
elif component_id is not None:
raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
# search by component_id
if component_id is not None:
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
return self.components[component_id]
results = self.get(name, collection, load_id)
# search with name/collection/load_id
results = self.search_components(name, collection, load_id)
if not results:
raise ValueError(f"No components found matching '{name}'")
@@ -845,20 +857,63 @@ class ComponentsManager:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
"""
Get component IDs by a list of names, optionally filtered by collection.
"""
ids = set()
for name in names:
ids.update(self._lookup_ids(name=name, collection=collection))
return list(ids)
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
"""
Get components by a list of IDs.
"""
components = {id: self.components[id] for id in ids}
if return_dict_with_names:
dict_to_return = {}
for comp_id, comp in components.items():
comp_name = self._id_to_name(comp_id)
if comp_name in dict_to_return:
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
dict_to_return[comp_name] = comp
return dict_to_return
else:
return components
def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
"""
Get components by a list of names, optionally filtered by collection.
"""
ids = self.get_ids(names, collection)
return self.get_components_by_ids(ids)
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.
For a dictionary with dot-separated keys like: {
This function is particularly useful for IP-Adapter attention processor patterns, where multiple
attention layers may share the same scale value. It groups dot-separated keys by their values
and finds the shortest common prefix for each group.
For example, given a dictionary with IP-Adapter attention processor patterns like:
{
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor': [0.3],
}
Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
'down_blocks': [0.6], 'up_blocks': [0.3]
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
{
'down_blocks.1.attentions.1.transformer_blocks': [0.6],
'up_blocks.1.attentions.0.transformer_blocks': [0.3]
}
This helps identify which attention layers share the same IP-Adapter scale values.
"""
# First group by values - convert lists to tuples to make them hashable
value_to_keys = {}