mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
up
This commit is contained in:
@@ -232,6 +232,7 @@ class AutoOffloadStrategy:
|
||||
|
||||
|
||||
from .modular_pipeline_utils import ComponentSpec
|
||||
import uuid
|
||||
class ComponentsManager:
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
@@ -240,26 +241,65 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
|
||||
def _get_by_collection(self, collection: str):
|
||||
"""
|
||||
Select components by collection name.
|
||||
"""
|
||||
selected_components = {}
|
||||
if collection in self.collections:
|
||||
component_ids = self.collections[collection]
|
||||
for component_id in component_ids:
|
||||
selected_components[component_id] = self.components[component_id]
|
||||
return selected_components
|
||||
|
||||
|
||||
def _get_by_load_id(self, load_id: str):
|
||||
"""
|
||||
Select components by its load_id.
|
||||
"""
|
||||
selected_components = {}
|
||||
for name, component in self.components.items():
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
||||
selected_components[name] = component
|
||||
return selected_components
|
||||
|
||||
|
||||
def add(self, name, component, collection: Optional[str] = None):
|
||||
|
||||
for comp_id, comp in self.components.items():
|
||||
if comp == component:
|
||||
logger.warning(f"Component '{name}' already exists in ComponentsManager")
|
||||
return comp_id
|
||||
|
||||
component_id = f"{name}_{uuid.uuid4()}"
|
||||
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
||||
name = f"{name}_{component._diffusers_load_id}"
|
||||
|
||||
if name in self.components:
|
||||
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
|
||||
components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id)
|
||||
if components_with_same_load_id:
|
||||
existing = ", ".join(components_with_same_load_id.keys())
|
||||
logger.warning(
|
||||
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
|
||||
)
|
||||
|
||||
self.components[name] = component
|
||||
self.added_time[name] = time.time()
|
||||
|
||||
# add component to components manager
|
||||
self.components[component_id] = component
|
||||
self.added_time[component_id] = time.time()
|
||||
if collection:
|
||||
if collection not in self.collections:
|
||||
self.collections[collection] = set()
|
||||
self.collections[collection].add(name)
|
||||
self.collections[collection].add(component_id)
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
|
||||
return component_id
|
||||
|
||||
|
||||
def remove(self, name):
|
||||
def remove(self, name: Union[str, List[str]]):
|
||||
|
||||
if name not in self.components:
|
||||
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
||||
@@ -275,27 +315,83 @@ class ComponentsManager:
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
# YiYi TODO: looking into improving the search pattern
|
||||
def get(self, names: Union[str, List[str]]):
|
||||
# YiYi TODO: looking into improving the search pattern and refactor the code
|
||||
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None):
|
||||
"""
|
||||
Get components by name with simple pattern matching.
|
||||
Select components by name with simple pattern matching.
|
||||
|
||||
Args:
|
||||
names: Component name(s) or pattern(s)
|
||||
Patterns:
|
||||
- "unet" : exact match
|
||||
- "!unet" : everything except exact match "unet"
|
||||
- "base_*" : everything starting with "base_"
|
||||
- "!base_*" : everything NOT starting with "base_"
|
||||
- "*unet*" : anything containing "unet"
|
||||
- "!*unet*" : anything NOT containing "unet"
|
||||
- "refiner|vae|unet" : anything containing any of these terms
|
||||
- "!refiner|vae|unet" : anything NOT containing any of these terms
|
||||
- "unet" : match any component with base name "unet" (e.g., unet_123abc)
|
||||
- "!unet" : everything except components with base name "unet"
|
||||
- "unet*" : anything with base name starting with "unet"
|
||||
- "!unet*" : anything with base name NOT starting with "unet"
|
||||
- "*unet*" : anything with base name containing "unet"
|
||||
- "!*unet*" : anything with base name NOT containing "unet"
|
||||
- "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
|
||||
- "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
|
||||
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
|
||||
|
||||
Returns:
|
||||
Single component if names is str and matches one component,
|
||||
dict of components if names matches multiple components or is a list
|
||||
"""
|
||||
|
||||
if collection:
|
||||
if collection not in self.collections:
|
||||
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
||||
return {}
|
||||
components = self._get_by_collection(collection)
|
||||
else:
|
||||
components = self.components
|
||||
|
||||
if load_id:
|
||||
components = self._get_by_load_id(load_id)
|
||||
|
||||
if names is None:
|
||||
return components
|
||||
|
||||
# Helper to extract base name from component_id
|
||||
def get_base_name(component_id):
|
||||
parts = component_id.split('_')
|
||||
# If the last part looks like a UUID, remove it
|
||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||
return '_'.join(parts[:-1])
|
||||
return component_id
|
||||
|
||||
# Create mapping from component_id to base_name for all components
|
||||
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
|
||||
|
||||
def matches_pattern(component_id, pattern, exact_match=False):
|
||||
"""
|
||||
Helper function to check if a component matches a pattern based on its base name.
|
||||
|
||||
Args:
|
||||
component_id: The component ID to check
|
||||
pattern: The pattern to match against
|
||||
exact_match: If True, only exact matches to base_name are considered
|
||||
"""
|
||||
base_name = base_names[component_id]
|
||||
|
||||
# Exact match with base name
|
||||
if exact_match:
|
||||
return pattern == base_name
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif pattern.endswith('*'):
|
||||
prefix = pattern[:-1]
|
||||
return base_name.startswith(prefix)
|
||||
|
||||
# Contains match (starts with *)
|
||||
elif pattern.startswith('*'):
|
||||
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
|
||||
return search in base_name
|
||||
|
||||
# Exact match (no wildcards)
|
||||
else:
|
||||
return pattern == base_name
|
||||
|
||||
if isinstance(names, str):
|
||||
# Check if this is a "not" pattern
|
||||
is_not_pattern = names.startswith('!')
|
||||
@@ -305,33 +401,49 @@ class ComponentsManager:
|
||||
# Handle OR patterns (containing |)
|
||||
if '|' in names:
|
||||
terms = names.split('|')
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
|
||||
matches = {}
|
||||
|
||||
for comp_id, comp in components.items():
|
||||
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
||||
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
|
||||
|
||||
# Check if any of the terms match this component
|
||||
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
||||
|
||||
# Flip the decision if this is a NOT pattern
|
||||
if is_not_pattern:
|
||||
should_include = not should_include
|
||||
|
||||
if should_include:
|
||||
matches[comp_id] = comp
|
||||
|
||||
log_msg = "NOT " if is_not_pattern else ""
|
||||
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
||||
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
||||
|
||||
# Exact match
|
||||
elif names in self.components:
|
||||
# Try exact match with a base name
|
||||
elif any(names == base_name for base_name in base_names.values()):
|
||||
# Find all components with this base name
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (base_names[comp_id] == names) != is_not_pattern
|
||||
}
|
||||
|
||||
if is_not_pattern:
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if name != names
|
||||
}
|
||||
logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
|
||||
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting component: {names}")
|
||||
return self.components[names]
|
||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||
|
||||
# If there's exactly one match and it's not a NOT pattern, return the component directly
|
||||
if len(matches) == 1 and not is_not_pattern:
|
||||
return next(iter(matches.values()))
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif names.endswith('*'):
|
||||
prefix = names[:-1]
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||
@@ -342,16 +454,27 @@ class ComponentsManager:
|
||||
elif names.startswith('*'):
|
||||
search = names[1:-1] if names.endswith('*') else names[1:]
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if (search in name) != is_not_pattern # Flip condition if not pattern
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (search in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||
|
||||
# Substring match (no wildcards, but not an exact component name)
|
||||
elif any(names in base_name for base_name in base_names.values()):
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (names in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Component '{names}' not found in ComponentsManager")
|
||||
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No components found matching pattern '{names}'")
|
||||
@@ -360,7 +483,7 @@ class ComponentsManager:
|
||||
elif isinstance(names, list):
|
||||
results = {}
|
||||
for name in names:
|
||||
result = self.get(name)
|
||||
result = self.get(name, collection)
|
||||
if isinstance(result, dict):
|
||||
results.update(result)
|
||||
else:
|
||||
@@ -409,6 +532,7 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
# YiYi TODO: add quantization info
|
||||
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get comprehensive information about a component.
|
||||
|
||||
@@ -431,14 +555,23 @@ class ComponentsManager:
|
||||
info = {
|
||||
"model_id": name,
|
||||
"added_time": self.added_time[name],
|
||||
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
|
||||
}
|
||||
|
||||
# Additional info for torch.nn.Module components
|
||||
if isinstance(component, torch.nn.Module):
|
||||
# Check for hook information
|
||||
has_hook = hasattr(component, "_hf_hook")
|
||||
execution_device = None
|
||||
if has_hook and hasattr(component._hf_hook, "execution_device"):
|
||||
execution_device = component._hf_hook.execution_device
|
||||
|
||||
info.update({
|
||||
"class_name": component.__class__.__name__,
|
||||
"size_gb": get_memory_footprint(component) / (1024**3),
|
||||
"adapters": None, # Default to None
|
||||
"has_hook": has_hook,
|
||||
"execution_device": execution_device,
|
||||
})
|
||||
|
||||
# Get adapters if applicable
|
||||
@@ -472,12 +605,56 @@ class ComponentsManager:
|
||||
return info
|
||||
|
||||
def __repr__(self):
|
||||
# Helper to get simple name without UUID
|
||||
def get_simple_name(name):
|
||||
# Extract the base name by splitting on underscore and taking first part
|
||||
# This assumes names are in format "name_uuid"
|
||||
parts = name.split('_')
|
||||
# If we have at least 2 parts and the last part looks like a UUID, remove it
|
||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||
return '_'.join(parts[:-1])
|
||||
return name
|
||||
|
||||
# Extract load_id if available
|
||||
def get_load_id(component):
|
||||
if hasattr(component, "_diffusers_load_id"):
|
||||
return component._diffusers_load_id
|
||||
return "N/A"
|
||||
|
||||
# Format device info compactly
|
||||
def format_device(component, info):
|
||||
if not info["has_hook"]:
|
||||
return str(getattr(component, 'device', 'N/A'))
|
||||
else:
|
||||
device = str(getattr(component, 'device', 'N/A'))
|
||||
exec_device = str(info['execution_device'] or 'N/A')
|
||||
return f"{device}({exec_device})"
|
||||
|
||||
# Get all simple names to calculate width
|
||||
simple_names = [get_simple_name(id) for id in self.components.keys()]
|
||||
|
||||
# Get max length of load_ids for models
|
||||
load_ids = [
|
||||
get_load_id(component)
|
||||
for component in self.components.values()
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
|
||||
]
|
||||
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
||||
|
||||
# Collection names
|
||||
collection_names = [
|
||||
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
|
||||
for name in self.components.keys()
|
||||
]
|
||||
|
||||
col_widths = {
|
||||
"id": max(15, max(len(id) for id in self.components.keys())),
|
||||
"name": max(15, max(len(name) for name in simple_names)),
|
||||
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
|
||||
"device": 10,
|
||||
"device": 15, # Reduced since using more compact format
|
||||
"dtype": 15,
|
||||
"size": 10,
|
||||
"load_id": max_load_id_len,
|
||||
"collection": max(10, max(len(str(c)) for c in collection_names))
|
||||
}
|
||||
|
||||
# Create the header lines
|
||||
@@ -494,17 +671,23 @@ class ComponentsManager:
|
||||
if models:
|
||||
output += "Models:\n" + dash_line
|
||||
# Column headers
|
||||
output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
|
||||
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
|
||||
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | "
|
||||
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
|
||||
output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
|
||||
output += dash_line
|
||||
|
||||
# Model entries
|
||||
for name, component in models.items():
|
||||
info = self.get_model_info(name)
|
||||
device = str(getattr(component, "device", "N/A"))
|
||||
simple_name = get_simple_name(name)
|
||||
device_str = format_device(component, info)
|
||||
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
||||
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||
output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
|
||||
load_id = get_load_id(component)
|
||||
collection = info["collection"] or "N/A"
|
||||
|
||||
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
||||
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
|
||||
output += dash_line
|
||||
|
||||
# Other components section
|
||||
@@ -513,12 +696,16 @@ class ComponentsManager:
|
||||
output += "\n"
|
||||
output += "Other Components:\n" + dash_line
|
||||
# Column headers for other components
|
||||
output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
|
||||
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n"
|
||||
output += dash_line
|
||||
|
||||
# Other component entries
|
||||
for name, component in others.items():
|
||||
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
|
||||
info = self.get_model_info(name)
|
||||
simple_name = get_simple_name(name)
|
||||
collection = info["collection"] or "N/A"
|
||||
|
||||
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
|
||||
output += dash_line
|
||||
|
||||
# Add additional component info
|
||||
@@ -526,7 +713,8 @@ class ComponentsManager:
|
||||
for name in self.components:
|
||||
info = self.get_model_info(name)
|
||||
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
|
||||
output += f"\n{name}:\n"
|
||||
simple_name = get_simple_name(name)
|
||||
output += f"\n{simple_name}:\n"
|
||||
if info.get("adapters") is not None:
|
||||
output += f" Adapters: {info['adapters']}\n"
|
||||
if info.get("ip_adapter"):
|
||||
|
||||
@@ -1101,7 +1101,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
# current component spec
|
||||
component_spec = self._component_specs.get(name)
|
||||
if component_spec is None:
|
||||
logger.warning(f"register_components: skipping unknown component '{name}'")
|
||||
logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'")
|
||||
continue
|
||||
|
||||
is_registered = hasattr(self, name)
|
||||
@@ -1143,17 +1143,17 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
current_module = getattr(self, name, None)
|
||||
# skip if the component is already registered with the same object
|
||||
if current_module is module:
|
||||
logger.info(f"register_components: {name} is already registered with same object, skipping")
|
||||
logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping")
|
||||
continue
|
||||
|
||||
# it module is not an instance of the expected type, still register it but with a warning
|
||||
if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
|
||||
logger.warning(f"register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
|
||||
logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
|
||||
|
||||
# warn if unregister
|
||||
if current_module is not None and module is None:
|
||||
logger.info(
|
||||
f"register_components: setting '{name}' to None "
|
||||
f"ModularLoader.register_components: setting '{name}' to None "
|
||||
f"(was {current_module.__class__.__name__})"
|
||||
)
|
||||
# same type, new instance → debug
|
||||
@@ -1162,7 +1162,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
and isinstance(module, current_module.__class__) \
|
||||
and current_module != module:
|
||||
logger.debug(
|
||||
f"register_components: replacing existing '{name}' "
|
||||
f"ModularLoader.register_components: replacing existing '{name}' "
|
||||
f"(same type {type(current_module).__name__}, new instance)"
|
||||
)
|
||||
|
||||
@@ -1343,7 +1343,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
|
||||
|
||||
if len(kwargs) > 0:
|
||||
raise logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
|
||||
|
||||
self.register_components(**passed_components)
|
||||
|
||||
Reference in New Issue
Block a user