mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
components manager: change get -> search_models; add get_ids, get_components_by_ids, get_components_by_names
This commit is contained in:
@@ -232,6 +232,8 @@ class AutoOffloadStrategy:
|
||||
|
||||
|
||||
class ComponentsManager:
|
||||
_available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"]
|
||||
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
self.added_time = OrderedDict() # Store when components were added
|
||||
@@ -239,9 +241,10 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
|
||||
def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None):
|
||||
"""
|
||||
Lookup component_ids by name, collection, or load_id.
|
||||
Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
|
||||
Returns a set of component_ids
|
||||
"""
|
||||
if components is None:
|
||||
components = self.components
|
||||
@@ -351,15 +354,16 @@ class ComponentsManager:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get(
|
||||
# YiYi TODO: rename to search_components for now, may remove this method
|
||||
def search_components(
|
||||
self,
|
||||
names: Union[str, List[str]] = None,
|
||||
names: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
load_id: Optional[str] = None,
|
||||
as_name_component_tuples: bool = False,
|
||||
return_dict_with_names: bool = True,
|
||||
):
|
||||
"""
|
||||
Select components by name with simple pattern matching.
|
||||
Search components by name with simple pattern matching. Optionally filter by collection or load_id.
|
||||
|
||||
Args:
|
||||
names: Component name(s) or pattern(s)
|
||||
@@ -375,34 +379,48 @@ class ComponentsManager:
|
||||
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
|
||||
collection: Optional collection to filter by
|
||||
load_id: Optional load_id to filter by
|
||||
as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
|
||||
instead of a dictionary with component IDs as keys
|
||||
return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
|
||||
If False, returns a dictionary with component IDs as keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping component IDs to components or list of (base_name, component) tuples if
|
||||
as_name_component_tuples=True
|
||||
Dictionary mapping component names to components if return_dict_with_names=True,
|
||||
or a dictionary mapping component IDs to components if return_dict_with_names=False
|
||||
"""
|
||||
|
||||
# select components based on collection and load_id filters
|
||||
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
|
||||
components = {k: self.components[k] for k in selected_ids}
|
||||
|
||||
# Helper to extract base name from component_id
|
||||
def get_base_name(component_id):
|
||||
parts = component_id.split("_")
|
||||
# If the last part looks like a UUID, remove it
|
||||
if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]:
|
||||
return "_".join(parts[:-1])
|
||||
return component_id
|
||||
|
||||
if names is None:
|
||||
if as_name_component_tuples:
|
||||
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
|
||||
|
||||
def get_return_dict(components, return_dict_with_names):
|
||||
"""
|
||||
Create a dictionary mapping component names to components if return_dict_with_names=True,
|
||||
or a dictionary mapping component IDs to components if return_dict_with_names=False,
|
||||
throw an error if duplicate component names are found when return_dict_with_names=True
|
||||
"""
|
||||
if return_dict_with_names:
|
||||
dict_to_return = {}
|
||||
for comp_id, comp in components.items():
|
||||
comp_name = self._id_to_name(comp_id)
|
||||
if comp_name in dict_to_return:
|
||||
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
|
||||
dict_to_return[comp_name] = comp
|
||||
return dict_to_return
|
||||
else:
|
||||
return components
|
||||
|
||||
# Create mapping from component_id to base_name for all components
|
||||
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
|
||||
|
||||
# if no names are provided, return the filtered components as it is
|
||||
if names is None:
|
||||
return get_return_dict(components, return_dict_with_names)
|
||||
|
||||
# if names is not a string, raise an error
|
||||
elif not isinstance(names, str):
|
||||
raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
|
||||
|
||||
# Create mapping from component_id to base_name for components to be used for pattern matching
|
||||
base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}
|
||||
|
||||
# Helper function to check if a component matches a pattern based on its base name
|
||||
def matches_pattern(component_id, pattern, exact_match=False):
|
||||
"""
|
||||
Helper function to check if a component matches a pattern based on its base name.
|
||||
@@ -432,113 +450,95 @@ class ComponentsManager:
|
||||
else:
|
||||
return pattern == base_name
|
||||
|
||||
if isinstance(names, str):
|
||||
# Check if this is a "not" pattern
|
||||
is_not_pattern = names.startswith("!")
|
||||
# Check if this is a "not" pattern
|
||||
is_not_pattern = names.startswith("!")
|
||||
if is_not_pattern:
|
||||
names = names[1:] # Remove the ! prefix
|
||||
|
||||
# Handle OR patterns (containing |)
|
||||
if "|" in names:
|
||||
terms = names.split("|")
|
||||
matches = {}
|
||||
|
||||
for comp_id, comp in components.items():
|
||||
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
||||
exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
|
||||
|
||||
# Check if any of the terms match this component
|
||||
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
||||
|
||||
# Flip the decision if this is a NOT pattern
|
||||
if is_not_pattern:
|
||||
should_include = not should_include
|
||||
|
||||
if should_include:
|
||||
matches[comp_id] = comp
|
||||
|
||||
log_msg = "NOT " if is_not_pattern else ""
|
||||
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
||||
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
||||
|
||||
# Try exact match with a base name
|
||||
elif any(names == base_name for base_name in base_names.values()):
|
||||
# Find all components with this base name
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (base_names[comp_id] == names) != is_not_pattern
|
||||
}
|
||||
|
||||
if is_not_pattern:
|
||||
names = names[1:] # Remove the ! prefix
|
||||
|
||||
# Handle OR patterns (containing |)
|
||||
if "|" in names:
|
||||
terms = names.split("|")
|
||||
matches = {}
|
||||
|
||||
for comp_id, comp in components.items():
|
||||
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
||||
exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
|
||||
|
||||
# Check if any of the terms match this component
|
||||
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
||||
|
||||
# Flip the decision if this is a NOT pattern
|
||||
if is_not_pattern:
|
||||
should_include = not should_include
|
||||
|
||||
if should_include:
|
||||
matches[comp_id] = comp
|
||||
|
||||
log_msg = "NOT " if is_not_pattern else ""
|
||||
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
||||
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
||||
|
||||
# Try exact match with a base name
|
||||
elif any(names == base_name for base_name in base_names.values()):
|
||||
# Find all components with this base name
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (base_names[comp_id] == names) != is_not_pattern
|
||||
}
|
||||
|
||||
if is_not_pattern:
|
||||
logger.info(
|
||||
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif names.endswith("*"):
|
||||
prefix = names[:-1]
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
||||
|
||||
# Contains match (starts with *)
|
||||
elif names.startswith("*"):
|
||||
search = names[1:-1] if names.endswith("*") else names[1:]
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (search in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||
|
||||
# Substring match (no wildcards, but not an exact component name)
|
||||
elif any(names in base_name for base_name in base_names.values()):
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (names in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
||||
|
||||
logger.info(
|
||||
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No components found matching pattern '{names}'")
|
||||
|
||||
if as_name_component_tuples:
|
||||
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
|
||||
# Prefix match (ends with *)
|
||||
elif names.endswith("*"):
|
||||
prefix = names[:-1]
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||
else:
|
||||
return matches
|
||||
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
||||
|
||||
elif isinstance(names, list):
|
||||
results = {}
|
||||
for name in names:
|
||||
result = self.get(name, collection, load_id, as_name_component_tuples=False)
|
||||
results.update(result)
|
||||
|
||||
if as_name_component_tuples:
|
||||
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
|
||||
# Contains match (starts with *)
|
||||
elif names.startswith("*"):
|
||||
search = names[1:-1] if names.endswith("*") else names[1:]
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (search in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||
else:
|
||||
return results
|
||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||
|
||||
# Substring match (no wildcards, but not an exact component name)
|
||||
elif any(names in base_name for base_name in base_names.values()):
|
||||
matches = {
|
||||
comp_id: comp
|
||||
for comp_id, comp in components.items()
|
||||
if (names in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No components found matching pattern '{names}'")
|
||||
|
||||
return get_return_dict(matches, return_dict_with_names)
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
|
||||
if not is_accelerate_available():
|
||||
@@ -582,16 +582,18 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
# YiYi TODO: add quantization info
|
||||
# YiYi TODO: (1) add quantization info
|
||||
def get_model_info(
|
||||
self, component_id: str, fields: Optional[Union[str, List[str]]] = None
|
||||
self,
|
||||
component_id: str,
|
||||
fields: Optional[Union[str, List[str]]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get comprehensive information about a component.
|
||||
|
||||
Args:
|
||||
component_id: Name of the component to get info for
|
||||
fields: Optional field(s) to return. Can be a string for single field or list of fields.
|
||||
If None, returns all fields.
|
||||
If None, uses the available_info_fields setting.
|
||||
|
||||
Returns:
|
||||
Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a
|
||||
@@ -601,6 +603,14 @@ class ComponentsManager:
|
||||
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
|
||||
|
||||
component = self.components[component_id]
|
||||
|
||||
# Validate fields if specified
|
||||
if fields is not None:
|
||||
if isinstance(fields, str):
|
||||
fields = [fields]
|
||||
for field in fields:
|
||||
if field not in self._available_info_fields:
|
||||
raise ValueError(f"Field '{field}' not found in available_info_fields")
|
||||
|
||||
# Build complete info dict first
|
||||
info = {
|
||||
@@ -649,15 +659,11 @@ class ComponentsManager:
|
||||
|
||||
# If fields specified, filter info
|
||||
if fields is not None:
|
||||
if isinstance(fields, str):
|
||||
# Single field requested, return just that value
|
||||
return {fields: info.get(fields)}
|
||||
else:
|
||||
# List of fields requested, return dict with just those fields
|
||||
return {k: v for k, v in info.items() if k in fields}
|
||||
|
||||
return info
|
||||
|
||||
return {k: v for k, v in info.items() if k in fields}
|
||||
else:
|
||||
return info
|
||||
|
||||
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the components table
|
||||
def __repr__(self):
|
||||
# Handle empty components case
|
||||
if not self.components:
|
||||
@@ -814,9 +820,14 @@ class ComponentsManager:
|
||||
load_id: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get a single component by name. Raises an error if multiple components match or none are found.
|
||||
Get a single component by either:
|
||||
(1) searching name (pattern matching), collection, or load_id.
|
||||
(2) passing in a component_id
|
||||
Raises an error if multiple components match or none are found.
|
||||
support pattern matching for name
|
||||
|
||||
Args:
|
||||
component_id: Optional component ID to get
|
||||
name: Component name or pattern
|
||||
collection: Optional collection to filter by
|
||||
load_id: Optional load_id to filter by
|
||||
@@ -828,15 +839,16 @@ class ComponentsManager:
|
||||
ValueError: If no components match or multiple components match
|
||||
"""
|
||||
|
||||
# if component_id is provided, return the component
|
||||
if component_id is not None and (name is not None or collection is not None or load_id is not None):
|
||||
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
|
||||
elif component_id is not None:
|
||||
raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
|
||||
|
||||
# search by component_id
|
||||
if component_id is not None:
|
||||
if component_id not in self.components:
|
||||
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
|
||||
return self.components[component_id]
|
||||
|
||||
results = self.get(name, collection, load_id)
|
||||
# search with name/collection/load_id
|
||||
results = self.search_components(name, collection, load_id)
|
||||
|
||||
if not results:
|
||||
raise ValueError(f"No components found matching '{name}'")
|
||||
@@ -845,20 +857,63 @@ class ComponentsManager:
|
||||
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
|
||||
|
||||
return next(iter(results.values()))
|
||||
|
||||
|
||||
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
    """
    Get component IDs by one name or a list of names, optionally filtered by collection.

    Each name is looked up individually through `self._lookup_ids(name=..., collection=...)`
    and the resulting IDs are merged into a single de-duplicated collection.

    Args:
        names: A single component name or a list of component names. A bare string is treated
            as one name (not as a sequence of characters). `None` is forwarded as a single
            lookup without a name filter; what that returns is decided by `_lookup_ids`.
        collection: Optional collection, forwarded unchanged to `_lookup_ids`.

    Returns:
        List of unique component IDs. The IDs are accumulated in a set, so their order in
        the returned list is not guaranteed.
    """
    # The annotation allows a plain string and the default is None; neither can be
    # iterated as "a list of names", so wrap them into a one-element list first.
    if names is None or isinstance(names, str):
        names = [names]

    ids = set()
    for name in names:
        ids.update(self._lookup_ids(name=name, collection=collection))
    return list(ids)
|
||||
|
||||
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
    """
    Get components by a list of IDs.

    Args:
        ids: Component IDs to fetch from `self.components`. A bare string is treated as a
            single ID rather than being iterated character by character.
        return_dict_with_names: If truthy (default), the returned dictionary is keyed by the
            component name derived via `self._id_to_name`. Otherwise it is keyed by component ID.

    Returns:
        Dictionary mapping component names (or component IDs) to components.

    Raises:
        KeyError: If one of the IDs is not a key of `self.components`.
        ValueError: If `return_dict_with_names` is truthy and two of the requested IDs map
            to the same component name.
    """
    # A lone string would otherwise be looked up one character at a time.
    if isinstance(ids, str):
        ids = [ids]

    components = {comp_id: self.components[comp_id] for comp_id in ids}

    if not return_dict_with_names:
        return components

    dict_to_return = {}
    for comp_id, comp in components.items():
        comp_name = self._id_to_name(comp_id)
        # Names are not unique across IDs; refuse to silently overwrite an entry.
        if comp_name in dict_to_return:
            raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
        dict_to_return[comp_name] = comp
    return dict_to_return
|
||||
|
||||
def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
    """
    Get components by a list of names, optionally filtered by collection.

    Resolves the names to component IDs with `self.get_ids` and then fetches the
    components with `self.get_components_by_ids` using its default arguments.

    Args:
        names: Component names to look up. A bare string is treated as a single name.
        collection: Optional collection, forwarded unchanged to `get_ids`.

    Returns:
        Whatever `get_components_by_ids` returns for the resolved IDs.
    """
    # Wrap a lone string so it is never split into characters further down the chain.
    if isinstance(names, str):
        names = [names]

    ids = self.get_ids(names, collection)
    return self.get_components_by_ids(ids)
|
||||
|
||||
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Summarizes a dictionary by finding common prefixes that share the same value.
|
||||
|
||||
For a dictionary with dot-separated keys like: {
|
||||
This function is particularly useful for IP-Adapter attention processor patterns, where multiple
|
||||
attention layers may share the same scale value. It groups dot-separated keys by their values
|
||||
and finds the shortest common prefix for each group.
|
||||
|
||||
For example, given a dictionary with IP-Adapter attention processor patterns like:
|
||||
{
|
||||
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
|
||||
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
|
||||
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
|
||||
'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor': [0.3],
|
||||
}
|
||||
|
||||
Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
|
||||
'down_blocks': [0.6], 'up_blocks': [0.3]
|
||||
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
|
||||
{
|
||||
'down_blocks.1.attentions.1.transformer_blocks': [0.6],
|
||||
'up_blocks.1.attentions.0.transformer_blocks': [0.3]
|
||||
}
|
||||
|
||||
This helps identify which attention layers share the same IP-Adapter scale values.
|
||||
"""
|
||||
# First group by values - convert lists to tuples to make them hashable
|
||||
value_to_keys = {}
|
||||
|
||||
Reference in New Issue
Block a user