Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00

[modular diffusers] introducing ModularLoader (#11462)

* cfg; slg; pag; sdxl without controlnet

---------

Co-authored-by: Aryan <aryan@huggingface.co>
@@ -249,7 +249,7 @@ else:
         "KarrasVePipeline",
         "LDMPipeline",
         "LDMSuperResolutionPipeline",
-        "ModularPipeline",
+        "ModularLoader",
         "PNDMPipeline",
         "RePaintPipeline",
         "ScoreSdeVePipeline",
@@ -502,7 +502,7 @@ else:
         "StableDiffusionXLImg2ImgPipeline",
         "StableDiffusionXLInpaintPipeline",
         "StableDiffusionXLInstructPix2PixPipeline",
-        "StableDiffusionXLModularPipeline",
+        "StableDiffusionXLModularLoader",
         "StableDiffusionXLPAGImg2ImgPipeline",
         "StableDiffusionXLPAGInpaintPipeline",
         "StableDiffusionXLPAGPipeline",
@@ -840,7 +840,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         KarrasVePipeline,
         LDMPipeline,
         LDMSuperResolutionPipeline,
-        ModularPipeline,
+        ModularLoader,
        PNDMPipeline,
        RePaintPipeline,
        ScoreSdeVePipeline,
@@ -1071,7 +1071,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLInpaintPipeline,
         StableDiffusionXLInstructPix2PixPipeline,
-        StableDiffusionXLModularPipeline,
+        StableDiffusionXLModularLoader,
         StableDiffusionXLPAGImg2ImgPipeline,
         StableDiffusionXLPAGInpaintPipeline,
         StableDiffusionXLPAGPipeline,

@@ -46,7 +46,7 @@ else:
         "AutoPipelineForInpainting",
         "AutoPipelineForText2Image",
     ]
-    _import_structure["modular_pipeline"] = ["ModularPipeline"]
+    _import_structure["modular_pipeline"] = ["ModularLoader"]
     _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
     _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
     _import_structure["ddim"] = ["DDIMPipeline"]
@@ -329,7 +329,7 @@ else:
             "StableDiffusionXLInpaintPipeline",
             "StableDiffusionXLInstructPix2PixPipeline",
             "StableDiffusionXLPipeline",
-            "StableDiffusionXLModularPipeline",
+            "StableDiffusionXLModularLoader",
             "StableDiffusionXLAutoPipeline",
         ]
     )
@@ -468,7 +468,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
     from .dit import DiTPipeline
     from .latent_diffusion import LDMSuperResolutionPipeline
-    from .modular_pipeline import ModularPipeline
+    from .modular_pipeline import ModularLoader
     from .pipeline_utils import (
         AudioPipelineOutput,
         DiffusionPipeline,
@@ -693,7 +693,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLInpaintPipeline,
         StableDiffusionXLInstructPix2PixPipeline,
-        StableDiffusionXLModularPipeline,
+        StableDiffusionXLModularLoader,
         StableDiffusionXLPipeline,
         StableDiffusionXLAutoPipeline,
     )

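The hunks above are mechanical renames inside diffusers' lazy-import tables. For readers unfamiliar with the pattern they touch, here is a minimal sketch of how `_import_structure` and the `TYPE_CHECKING` branch cooperate (a simplified stand-in; the real machinery is `_LazyModule` in `diffusers.utils`):

```python
# Minimal sketch of the lazy-import pattern these hunks edit (an __init__.py
# in miniature; the real implementation is _LazyModule in diffusers.utils).
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "modular_pipeline": ["ModularLoader"],   # the rename this PR performs
    "pipeline_utils": ["DiffusionPipeline"],
}

if TYPE_CHECKING:
    # Static type checkers resolve the real imports eagerly...
    from .modular_pipeline import ModularLoader
else:
    # ...while at runtime, attribute access triggers the submodule import.
    def __getattr__(name):
        for module_name, exports in _import_structure.items():
            if name in exports:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```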
@@ -26,6 +26,7 @@ from ..utils import (
     logging,
 )
 from ..models.modeling_utils import ModelMixin
+from .modular_pipeline_utils import ComponentSpec


 if is_accelerate_available():
@@ -229,54 +230,175 @@ class AutoOffloadStrategy:
         return hooks_to_offload


+from .modular_pipeline_utils import ComponentSpec
+import uuid
+
+
 class ComponentsManager:
     def __init__(self):
         self.components = OrderedDict()
         self.added_time = OrderedDict()  # Store when components were added
+        self.collections = OrderedDict()  # collection_name -> set of component_names
         self.model_hooks = None
         self._auto_offload_enabled = False

-    def add(self, name, component):
-        if name in self.components:
-            logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
-        self.components[name] = component
-        self.added_time[name] = time.time()
-
-        if self._auto_offload_enabled:
-            self.enable_auto_cpu_offload(self._auto_offload_device)
-
+    def _get_by_collection(self, collection: str):
+        """
+        Select components by collection name.
+        """
+        selected_components = {}
+        if collection in self.collections:
+            component_ids = self.collections[collection]
+            for component_id in component_ids:
+                selected_components[component_id] = self.components[component_id]
+        return selected_components
+
+    def _get_by_load_id(self, load_id: str):
+        """
+        Select components by their load_id.
+        """
+        selected_components = {}
+        for name, component in self.components.items():
+            if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
+                selected_components[name] = component
+        return selected_components
+
+    def add(self, name, component, collection: Optional[str] = None):
+        for comp_id, comp in self.components.items():
+            if comp == component:
+                logger.warning(f"Component '{name}' already exists in ComponentsManager")
+                return comp_id
+
+        component_id = f"{name}_{uuid.uuid4()}"
+
+        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+            components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id)
+            if components_with_same_load_id:
+                existing = ", ".join(components_with_same_load_id.keys())
+                logger.warning(
+                    f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
+                    f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
+                )
+
+        # add component to components manager
+        self.components[component_id] = component
+        self.added_time[component_id] = time.time()
+        if collection:
+            if collection not in self.collections:
+                self.collections[collection] = set()
+            self.collections[collection].add(component_id)
+
+        if self._auto_offload_enabled:
+            self.enable_auto_cpu_offload(self._auto_offload_device)
+
+        logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
+        return component_id

-    def remove(self, name):
+    def remove(self, name: Union[str, List[str]]):
         if name not in self.components:
             logger.warning(f"Component '{name}' not found in ComponentsManager")
             return

         self.components.pop(name)
         self.added_time.pop(name)

+        for collection in self.collections:
+            if name in self.collections[collection]:
+                self.collections[collection].remove(name)
+
         if self._auto_offload_enabled:
             self.enable_auto_cpu_offload(self._auto_offload_device)

     # YiYi TODO: looking into improving the search pattern
-    def get(self, names: Union[str, List[str]]):
+    def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
+            as_name_component_tuples: bool = False):
         """
-        Get components by name with simple pattern matching.
+        Select components by name with simple pattern matching.

         Args:
             names: Component name(s) or pattern(s)
                 Patterns:
-                - "unet" : exact match
-                - "!unet" : everything except exact match "unet"
-                - "base_*" : everything starting with "base_"
-                - "!base_*" : everything NOT starting with "base_"
-                - "*unet*" : anything containing "unet"
-                - "!*unet*" : anything NOT containing "unet"
-                - "refiner|vae|unet" : anything containing any of these terms
-                - "!refiner|vae|unet" : anything NOT containing any of these terms
+                - "unet" : match any component with base name "unet" (e.g., unet_123abc)
+                - "!unet" : everything except components with base name "unet"
+                - "unet*" : anything with base name starting with "unet"
+                - "!unet*" : anything with base name NOT starting with "unet"
+                - "*unet*" : anything with base name containing "unet"
+                - "!*unet*" : anything with base name NOT containing "unet"
+                - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
+                - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
+                - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
+            collection: Optional collection to filter by
+            load_id: Optional load_id to filter by
+            as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
+                instead of a dictionary with component IDs as keys

         Returns:
-            Single component if names is str and matches one component,
-            dict of components if names matches multiple components or is a list
+            Dictionary mapping component IDs to components,
+            or list of (base_name, component) tuples if as_name_component_tuples=True
         """
+
+        if collection:
+            if collection not in self.collections:
+                logger.warning(f"Collection '{collection}' not found in ComponentsManager")
+                return [] if as_name_component_tuples else {}
+            components = self._get_by_collection(collection)
+        else:
+            components = self.components
+
+        if load_id:
+            components = self._get_by_load_id(load_id)
+
+        # Helper to extract base name from component_id
+        def get_base_name(component_id):
+            parts = component_id.split('_')
+            # If the last part looks like a UUID, remove it
+            if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
+                return '_'.join(parts[:-1])
+            return component_id
+
+        if names is None:
+            if as_name_component_tuples:
+                return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
+            else:
+                return components
+
+        # Create mapping from component_id to base_name for all components
+        base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
+
+        def matches_pattern(component_id, pattern, exact_match=False):
+            """
+            Helper function to check if a component matches a pattern based on its base name.
+
+            Args:
+                component_id: The component ID to check
+                pattern: The pattern to match against
+                exact_match: If True, only exact matches to base_name are considered
+            """
+            base_name = base_names[component_id]
+
+            # Exact match with base name
+            if exact_match:
+                return pattern == base_name
+
+            # Prefix match (ends with *)
+            elif pattern.endswith('*'):
+                prefix = pattern[:-1]
+                return base_name.startswith(prefix)
+
+            # Contains match (starts with *)
+            elif pattern.startswith('*'):
+                search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
+                return search in base_name
+
+            # Exact match (no wildcards)
+            else:
+                return pattern == base_name

         if isinstance(names, str):
             # Check if this is a "not" pattern
             is_not_pattern = names.startswith('!')
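To make the matching rules concrete, a standalone mirror of the base-name logic (a sketch of the semantics only, not the library API; `get_base_name`/`matches` are local helpers, not diffusers functions):

```python
# Standalone mirror of the base-name pattern matching described above.
def get_base_name(component_id: str) -> str:
    parts = component_id.split("_")
    # Strip a trailing UUID-looking segment, as ComponentsManager.get does.
    if len(parts) > 1 and len(parts[-1]) >= 8 and "-" in parts[-1]:
        return "_".join(parts[:-1])
    return component_id

def matches(component_id: str, pattern: str) -> bool:
    base = get_base_name(component_id)
    negate = pattern.startswith("!")
    if negate:
        pattern = pattern[1:]
    terms = pattern.split("|")

    def term_hit(term: str) -> bool:
        if term.startswith("*"):
            return term.strip("*") in base       # contains match
        if term.endswith("*"):
            return base.startswith(term[:-1])    # prefix match
        return base == term                      # exact base-name match

    hit = any(term_hit(t) for t in terms)
    return hit != negate  # flip for "!" patterns

cid = "unet_a1b2c3d4-0000-0000-0000-000000000000"
assert matches(cid, "unet")       # exact base-name match
assert matches(cid, "unet*")      # prefix match
assert matches(cid, "*net*")      # contains match
assert not matches(cid, "!unet")  # negated
```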
@@ -286,33 +408,45 @@ class ComponentsManager:
             # Handle OR patterns (containing |)
             if '|' in names:
                 terms = names.split('|')
-                matches = {
-                    name: comp for name, comp in self.components.items()
-                    if any((term in name) != is_not_pattern for term in terms)  # Flip condition if not pattern
-                }
-                if is_not_pattern:
-                    logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
-                else:
-                    logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
+                matches = {}
+
+                for comp_id, comp in components.items():
+                    # For OR patterns with exact names (no wildcards), we do exact matching on base names
+                    exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
+
+                    # Check if any of the terms match this component
+                    should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
+
+                    # Flip the decision if this is a NOT pattern
+                    if is_not_pattern:
+                        should_include = not should_include
+
+                    if should_include:
+                        matches[comp_id] = comp
+
+                log_msg = "NOT " if is_not_pattern else ""
+                match_type = "exactly matching" if exact_match else "matching any of patterns"
+                logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")

-            # Exact match
-            elif names in self.components:
+            # Try exact match with a base name
+            elif any(names == base_name for base_name in base_names.values()):
+                # Find all components with this base name
+                matches = {
+                    comp_id: comp for comp_id, comp in components.items()
+                    if (base_names[comp_id] == names) != is_not_pattern
+                }
+
                 if is_not_pattern:
-                    matches = {
-                        name: comp for name, comp in self.components.items()
-                        if name != names
-                    }
-                    logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
+                    logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
                 else:
-                    logger.info(f"Getting component: {names}")
-                    return self.components[names]
+                    logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")

             # Prefix match (ends with *)
             elif names.endswith('*'):
                 prefix = names[:-1]
                 matches = {
-                    name: comp for name, comp in self.components.items()
-                    if name.startswith(prefix) != is_not_pattern  # Flip condition if not pattern
+                    comp_id: comp for comp_id, comp in components.items()
+                    if base_names[comp_id].startswith(prefix) != is_not_pattern
                 }
                 if is_not_pattern:
                     logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
@@ -323,30 +457,46 @@ class ComponentsManager:
             elif names.startswith('*'):
                 search = names[1:-1] if names.endswith('*') else names[1:]
                 matches = {
-                    name: comp for name, comp in self.components.items()
-                    if (search in name) != is_not_pattern  # Flip condition if not pattern
+                    comp_id: comp for comp_id, comp in components.items()
+                    if (search in base_names[comp_id]) != is_not_pattern
                 }
                 if is_not_pattern:
                     logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
                 else:
                     logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

+            # Substring match (no wildcards, but not an exact component name)
+            elif any(names in base_name for base_name in base_names.values()):
+                matches = {
+                    comp_id: comp for comp_id, comp in components.items()
+                    if (names in base_names[comp_id]) != is_not_pattern
+                }
+                if is_not_pattern:
+                    logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
+                else:
+                    logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
+
             else:
-                raise ValueError(f"Component '{names}' not found in ComponentsManager")
+                raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")

             if not matches:
                 raise ValueError(f"No components found matching pattern '{names}'")
-            return matches if len(matches) > 1 else next(iter(matches.values()))
+
+            if as_name_component_tuples:
+                return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
+            else:
+                return matches

         elif isinstance(names, list):
             results = {}
             for name in names:
-                result = self.get(name)
-                if isinstance(result, dict):
-                    results.update(result)
-                else:
-                    results[name] = result
-            return results
+                result = self.get(name, collection, load_id, as_name_component_tuples=False)
+                results.update(result)
+
+            if as_name_component_tuples:
+                return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
+            else:
+                return results

         else:
             raise ValueError(f"Invalid type for names: {type(names)}")
@@ -390,6 +540,7 @@ class ComponentsManager:
         self.model_hooks = None
         self._auto_offload_enabled = False

+    # YiYi TODO: add quantization info
     def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
         """Get comprehensive information about a component.

@@ -412,14 +563,23 @@ class ComponentsManager:
         info = {
             "model_id": name,
             "added_time": self.added_time[name],
+            "collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
         }

         # Additional info for torch.nn.Module components
         if isinstance(component, torch.nn.Module):
+            # Check for hook information
+            has_hook = hasattr(component, "_hf_hook")
+            execution_device = None
+            if has_hook and hasattr(component._hf_hook, "execution_device"):
+                execution_device = component._hf_hook.execution_device
+
             info.update({
                 "class_name": component.__class__.__name__,
                 "size_gb": get_memory_footprint(component) / (1024**3),
                 "adapters": None,  # Default to None
+                "has_hook": has_hook,
+                "execution_device": execution_device,
             })

         # Get adapters if applicable
@@ -453,12 +613,56 @@ class ComponentsManager:
         return info

     def __repr__(self):
+        # Helper to get simple name without UUID
+        def get_simple_name(name):
+            # Extract the base name by splitting on underscore and taking first part
+            # This assumes names are in format "name_uuid"
+            parts = name.split('_')
+            # If we have at least 2 parts and the last part looks like a UUID, remove it
+            if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
+                return '_'.join(parts[:-1])
+            return name
+
+        # Extract load_id if available
+        def get_load_id(component):
+            if hasattr(component, "_diffusers_load_id"):
+                return component._diffusers_load_id
+            return "N/A"
+
+        # Format device info compactly
+        def format_device(component, info):
+            if not info["has_hook"]:
+                return str(getattr(component, 'device', 'N/A'))
+            else:
+                device = str(getattr(component, 'device', 'N/A'))
+                exec_device = str(info['execution_device'] or 'N/A')
+                return f"{device}({exec_device})"
+
+        # Get all simple names to calculate width
+        simple_names = [get_simple_name(id) for id in self.components.keys()]
+
+        # Get max length of load_ids for models
+        load_ids = [
+            get_load_id(component)
+            for component in self.components.values()
+            if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
+        ]
+        max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
+
+        # Collection names
+        collection_names = [
+            next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
+            for name in self.components.keys()
+        ]
+
         col_widths = {
             "id": max(15, max(len(id) for id in self.components.keys())),
+            "name": max(15, max(len(name) for name in simple_names)),
             "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
-            "device": 10,
+            "device": 15,  # Reduced since using more compact format
             "dtype": 15,
+            "size": 10,
+            "load_id": max_load_id_len,
+            "collection": max(10, max(len(str(c)) for c in collection_names))
         }

         # Create the header lines
@@ -475,17 +679,23 @@ class ComponentsManager:
         if models:
             output += "Models:\n" + dash_line
             # Column headers
-            output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
-            output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
+            output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | "
+            output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
+            output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
             output += dash_line

             # Model entries
             for name, component in models.items():
                 info = self.get_model_info(name)
-                device = str(getattr(component, "device", "N/A"))
+                simple_name = get_simple_name(name)
+                device_str = format_device(component, info)
                 dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
-                output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
-                output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
+                load_id = get_load_id(component)
+                collection = info["collection"] or "N/A"
+
+                output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
+                output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
+                output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
             output += dash_line

         # Other components section
@@ -494,12 +704,16 @@ class ComponentsManager:
             output += "\n"
             output += "Other Components:\n" + dash_line
             # Column headers for other components
-            output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
+            output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n"
             output += dash_line

             # Other component entries
             for name, component in others.items():
-                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
+                info = self.get_model_info(name)
+                simple_name = get_simple_name(name)
+                collection = info["collection"] or "N/A"
+
+                output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
             output += dash_line

         # Add additional component info
@@ -507,7 +721,8 @@ class ComponentsManager:
         for name in self.components:
             info = self.get_model_info(name)
             if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
-                output += f"\n{name}:\n"
+                simple_name = get_simple_name(name)
+                output += f"\n{simple_name}:\n"
                 if info.get("adapters") is not None:
                     output += f"  Adapters: {info['adapters']}\n"
                 if info.get("ip_adapter"):
@@ -516,7 +731,7 @@ class ComponentsManager:

         return output

-    def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
+    def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
         """
         Load components from a pretrained model and add them to the manager.

@@ -526,17 +741,12 @@ class ComponentsManager:
                 If provided, components will be named as "{prefix}_{component_name}"
             **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
         """
-        from ..pipelines.pipeline_utils import DiffusionPipeline
-
-        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        for name, component in pipe.components.items():
-
-            if component is None:
-                continue
-
-            # Add prefix if specified
-            component_name = f"{prefix}_{name}" if prefix else name
-
+        subfolder = kwargs.pop("subfolder", None)
+        # YiYi TODO: extend AutoModel to support non-diffusers models
+        if subfolder:
+            from ..models import AutoModel
+            component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
+            component_name = f"{prefix}_{subfolder}" if prefix else subfolder
             if component_name not in self.components:
                 self.add(component_name, component)
             else:
@@ -545,6 +755,50 @@ class ComponentsManager:
                     f"1. remove the existing component with remove('{component_name}')\n"
                     f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
                 )
+        else:
+            from ..pipelines.pipeline_utils import DiffusionPipeline
+            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            for name, component in pipe.components.items():
+                if component is None:
+                    continue
+
+                # Add prefix if specified
+                component_name = f"{prefix}_{name}" if prefix else name
+
+                if component_name not in self.components:
+                    self.add(component_name, component)
+                else:
+                    logger.warning(
+                        f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
+                        f"1. remove the existing component with remove('{component_name}')\n"
+                        f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
+                    )
+
+    def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
+        """
+        Get a single component by name. Raises an error if multiple components match or none are found.
+
+        Args:
+            name: Component name or pattern
+            collection: Optional collection to filter by
+            load_id: Optional load_id to filter by
+
+        Returns:
+            A single component
+
+        Raises:
+            ValueError: If no components match or multiple components match
+        """
+        results = self.get(name, collection, load_id)
+
+        if not results:
+            raise ValueError(f"No components found matching '{name}'")
+
+        if len(results) > 1:
+            raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
+
+        return next(iter(results.values()))


 def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
     """Summarizes a dictionary by finding common prefixes that share the same value.

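Taken together, a hypothetical end-to-end use of the manager (module path and repo id are illustrative; method names follow the reconstructed hunks above, where `add_from_pretrained` is renamed to `from_pretrained`):

```python
# Hypothetical usage of ComponentsManager as introduced in this PR
# (import path and repo id are assumptions, not confirmed by the diff).
from diffusers.pipelines.components_manager import ComponentsManager

manager = ComponentsManager()

# Register every component of a pipeline under the "base" prefix; each model
# gets a unique "name_<uuid>" id plus a _diffusers_load_id used for
# duplicate-load warnings.
manager.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", prefix="base")

unet = manager.get_one("base_unet")       # exactly one match, or ValueError
encoders = manager.get("*text_encoder*")  # dict: component_id -> component
print(manager)                            # Name | Class | Device | Dtype | Size | Load ID | Collection
```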
File diff suppressed because it is too large.

src/diffusers/pipelines/modular_pipeline_utils.py (new file, 592 lines)
@@ -0,0 +1,592 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal

from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin

if is_torch_available():
    import torch


# YiYi TODO:
# 1. validate the dataclass fields
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
    """Specification for a pipeline component.

    A component can be created in two ways:
    1. From scratch using __init__ with a config dict
    2. Using `from_pretrained`

    Attributes:
        name: Name of the component
        type_hint: Type of the component (e.g. UNet2DConditionModel)
        description: Optional description of the component
        config: Optional config dict for __init__ creation
        repo: Optional repo path for from_pretrained creation
        subfolder: Optional subfolder in repo
        variant: Optional variant in repo
        revision: Optional revision in repo
        default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
    """
    name: Optional[str] = None
    type_hint: Optional[Type] = None
    description: Optional[str] = None
    config: Optional[FrozenDict[str, Any]] = None
    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
    subfolder: Optional[str] = field(default=None, metadata={"loading": True})
    variant: Optional[str] = field(default=None, metadata={"loading": True})
    revision: Optional[str] = field(default=None, metadata={"loading": True})
    default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"

    def __hash__(self):
        """Make ComponentSpec hashable, keyed by name, load_id, and creation method."""
        return hash((self.name, self.load_id, self.default_creation_method))

    def __eq__(self, other):
        """Compare ComponentSpec objects based on name and load_id."""
        if not isinstance(other, ComponentSpec):
            return False
        return (self.name == other.name and
                self.load_id == other.load_id and
                self.default_creation_method == other.default_creation_method)

    @classmethod
    def from_component(cls, name: str, component: torch.nn.Module) -> Any:
        """Create a ComponentSpec from a component created by the `create` method."""

        if not hasattr(component, "_diffusers_load_id"):
            raise ValueError("Component is not created by `create` method")

        type_hint = component.__class__

        if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin):
            config = component.config
        else:
            config = None

        load_spec = cls.decode_load_id(component._diffusers_load_id)

        return cls(name=name, type_hint=type_hint, config=config, **load_spec)

    @classmethod
    def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any:
        """Create a ComponentSpec from a load_id string."""
        if load_id == "null":
            raise ValueError("Cannot create ComponentSpec from null load_id")

        # Decode the load_id into a dictionary of loading fields
        load_fields = cls.decode_load_id(load_id)

        # Create a new ComponentSpec instance with the decoded fields
        return cls(name=name, **load_fields)

    @classmethod
    def loading_fields(cls) -> List[str]:
        """
        Return the names of all loading-related fields
        (i.e. those whose field.metadata["loading"] is True).
        """
        return [f.name for f in fields(cls) if f.metadata.get("loading", False)]

    @property
    def load_id(self) -> str:
        """
        Unique identifier for this spec's pretrained load,
        composed of repo|subfolder|variant|revision (no empty segments).
        """
        parts = [getattr(self, k) for k in self.loading_fields()]
        parts = ["null" if p is None else p for p in parts]
        return "|".join(p for p in parts if p)

    @classmethod
    def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
        """
        Decode a load_id string back into a dictionary of loading fields and values.

        Args:
            load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
                where None values are represented as "null"

        Returns:
            Dict mapping loading field names to their values, e.g.
            {
                "repo": "path/to/repo",
                "subfolder": "subfolder",
                "variant": "variant",
                "revision": "revision"
            }
            If a segment value is "null", it is replaced with None. If load_id is "null"
            (the component was not loaded from pretrained), every field is None.
        """

        # Get all loading fields in order
        loading_fields = cls.loading_fields()
        result = {f: None for f in loading_fields}

        if load_id == "null":
            return result

        # Split the load_id
        parts = load_id.split("|")

        # Map parts to loading fields by position
        for i, part in enumerate(parts):
            if i < len(loading_fields):
                # Convert "null" string back to None
                result[loading_fields[i]] = None if part == "null" else part

        return result
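A standalone mirror of the encode/decode round trip, using the same "repo|subfolder|variant|revision" scheme (a sketch of the behavior only, not the class itself):

```python
# Standalone mirror of ComponentSpec's load_id encoding/decoding.
LOADING_FIELDS = ["repo", "subfolder", "variant", "revision"]

def encode(**fields):
    parts = [fields.get(k) for k in LOADING_FIELDS]
    return "|".join("null" if p is None else p for p in parts)

def decode(load_id):
    result = dict.fromkeys(LOADING_FIELDS)
    if load_id == "null":
        return result  # not loaded from pretrained: every field stays None
    for key, part in zip(LOADING_FIELDS, load_id.split("|")):
        result[key] = None if part == "null" else part
    return result

lid = encode(repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
assert lid == "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"
assert decode(lid)["subfolder"] == "unet"
assert decode(lid)["variant"] is None
```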
    # YiYi TODO: add validator
    def create(self, **kwargs) -> Any:
        """Create the component using the preferred creation method."""

        # from_pretrained creation
        if self.default_creation_method == "from_pretrained":
            return self.create_from_pretrained(**kwargs)
        elif self.default_creation_method == "from_config":
            # from_config creation
            return self.create_from_config(**kwargs)
        else:
            raise ValueError(f"Invalid creation method: {self.default_creation_method}")

    def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
        """Create component using from_config with config."""

        if self.type_hint is None or not isinstance(self.type_hint, type):
            raise ValueError(
                "`type_hint` is required when using from_config creation method."
            )

        config = config or self.config or {}

        if issubclass(self.type_hint, ConfigMixin):
            component = self.type_hint.from_config(config, **kwargs)
        else:
            signature_params = inspect.signature(self.type_hint.__init__).parameters
            init_kwargs = {}
            for k, v in config.items():
                if k in signature_params:
                    init_kwargs[k] = v
            for k, v in kwargs.items():
                if k in signature_params:
                    init_kwargs[k] = v
            component = self.type_hint(**init_kwargs)

        component._diffusers_load_id = "null"
        if hasattr(component, "config"):
            self.config = component.config

        return component

    # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
    def create_from_pretrained(self, **kwargs) -> Any:
        """Create component using from_pretrained."""

        passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
        load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
        # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
        repo = load_kwargs.pop("repo", None)
        if repo is None:
            raise ValueError(
                "`repo` info is required when using from_pretrained creation method "
                "(you can directly set it in the `repo` field of the ComponentSpec or pass it as an argument)"
            )

        if self.type_hint is None:
            try:
                from diffusers import AutoModel
                component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
            except Exception as e:
                raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}")
            self.type_hint = component.__class__
        else:
            try:
                component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
            except Exception as e:
                raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}")

        if repo != self.repo:
            self.repo = repo
        for k, v in passed_loading_kwargs.items():
            if v is not None:
                setattr(self, k, v)
        component._diffusers_load_id = self.load_id

        return component
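A hedged sketch of the two creation paths (import paths and repo id are assumptions for illustration):

```python
# Hedged usage sketch of ComponentSpec.create (paths/repo id illustrative).
from diffusers import UNet2DConditionModel
from diffusers.pipelines.modular_pipeline_utils import ComponentSpec  # assumed module path

unet_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
)
unet = unet_spec.create()       # from_pretrained path: loads the checkpoint
print(unet._diffusers_load_id)  # "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"

# Round trip: rebuild an equivalent spec from the live component.
assert ComponentSpec.from_component("unet", unet) == unet_spec
```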
@dataclass
class ConfigSpec:
    """Specification for a pipeline configuration parameter."""
    name: str
    default: Any
    description: Optional[str] = None


@dataclass
class InputParam:
    """Specification for an input parameter."""
    name: str
    type_hint: Any = None
    default: Any = None
    required: bool = False
    description: str = ""

    def __repr__(self):
        return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"


@dataclass
class OutputParam:
    """Specification for an output parameter."""
    name: str
    type_hint: Any = None
    description: str = ""

    def __repr__(self):
        return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"

def format_inputs_short(inputs):
    """
    Format input parameters into a string representation, with required params first followed by optional ones.

    Args:
        inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params

    Returns:
        str: Formatted string of input parameters

    Example:
        >>> inputs = [
        ...     InputParam(name="prompt", required=True),
        ...     InputParam(name="image", required=True),
        ...     InputParam(name="guidance_scale", required=False, default=7.5),
        ...     InputParam(name="num_inference_steps", required=False, default=50)
        ... ]
        >>> format_inputs_short(inputs)
        'prompt, image, guidance_scale=7.5, num_inference_steps=50'
    """
    required_inputs = [param for param in inputs if param.required]
    optional_inputs = [param for param in inputs if not param.required]

    required_str = ", ".join(param.name for param in required_inputs)
    optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)

    inputs_str = required_str
    if optional_str:
        inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str

    return inputs_str


def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs):
    """
    Formats intermediate inputs and outputs of a block into a string representation.

    Args:
        intermediates_inputs: List of intermediate input parameters
        required_intermediates_inputs: List of required intermediate input names
        intermediates_outputs: List of intermediate output parameters

    Returns:
        str: Formatted string like:
            Intermediates:
                - inputs: Required(latents), dtype
                - modified: latents  # variables that appear in both inputs and outputs
                - outputs: images  # new outputs only
    """
    # Handle inputs
    input_parts = []
    for inp in intermediates_inputs:
        if inp.name in required_intermediates_inputs:
            input_parts.append(f"Required({inp.name})")
        else:
            input_parts.append(inp.name)

    # Handle modified variables (appear in both inputs and outputs)
    inputs_set = {inp.name for inp in intermediates_inputs}
    modified_parts = []
    new_output_parts = []

    for out in intermediates_outputs:
        if out.name in inputs_set:
            modified_parts.append(out.name)
        else:
            new_output_parts.append(out.name)

    result = []
    if input_parts:
        result.append(f"    - inputs: {', '.join(input_parts)}")
    if modified_parts:
        result.append(f"    - modified: {', '.join(modified_parts)}")
    if new_output_parts:
        result.append(f"    - outputs: {', '.join(new_output_parts)}")

    return "\n".join(result) if result else "  (none)"
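A quick demo of the two formatters above (module path assumed from the new-file header; run once this PR is installed):

```python
from diffusers.pipelines.modular_pipeline_utils import (  # assumed module path
    InputParam,
    OutputParam,
    format_inputs_short,
    format_intermediates_short,
)

inputs = [InputParam("prompt", required=True), InputParam("guidance_scale", default=7.5)]
print(format_inputs_short(inputs))
# prompt, guidance_scale=7.5

print(format_intermediates_short(
    intermediates_inputs=[InputParam("latents"), InputParam("dtype")],
    required_intermediates_inputs=["latents"],
    intermediates_outputs=[OutputParam("latents"), OutputParam("images")],
))
#     - inputs: Required(latents), dtype
#     - modified: latents
#     - outputs: images
```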
def format_params(params, header="Args", indent_level=4, max_line_length=115):
    """Format a list of InputParam or OutputParam objects into a readable string representation.

    Args:
        params: List of InputParam or OutputParam objects to format
        header: Header text to use (e.g. "Args" or "Returns")
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all parameters
    """
    if not params:
        return ""

    base_indent = " " * indent_level
    param_indent = " " * (indent_level + 4)
    desc_indent = " " * (indent_level + 8)
    formatted_params = []

    def get_type_str(type_hint):
        if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
            types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
            return f"Union[{', '.join(types)}]"
        return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)

    def wrap_text(text, indent, max_length):
        """Wrap text while preserving markdown links and maintaining indentation."""
        words = text.split()
        lines = []
        current_line = []
        current_length = 0

        for word in words:
            word_length = len(word) + (1 if current_line else 0)

            if current_line and current_length + word_length > max_length:
                lines.append(" ".join(current_line))
                current_line = [word]
                current_length = len(word)
            else:
                current_line.append(word)
                current_length += word_length

        if current_line:
            lines.append(" ".join(current_line))

        return f"\n{indent}".join(lines)

    # Add the header
    formatted_params.append(f"{base_indent}{header}:")

    for param in params:
        # Format parameter name and type
        type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
        param_str = f"{param_indent}{param.name} (`{type_str}`"

        # Add optional tag and default value if parameter is an InputParam and optional
        if hasattr(param, "required"):
            if not param.required:
                param_str += ", *optional*"
                if param.default is not None:
                    param_str += f", defaults to {param.default}"
        param_str += "):"

        # Add description on a new line with additional indentation and wrapping
        if param.description:
            desc = re.sub(
                r'\[(.*?)\]\((https?://[^\s\)]+)\)',
                r'[\1](\2)',
                param.description
            )
            wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
            param_str += f"\n{desc_indent}{wrapped_desc}"

        formatted_params.append(param_str)

    return "\n\n".join(formatted_params)


def format_input_params(input_params, indent_level=4, max_line_length=115):
    """Format a list of InputParam objects into a readable string representation.

    Args:
        input_params: List of InputParam objects to format
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all input parameters
    """
    return format_params(input_params, "Inputs", indent_level, max_line_length)


def format_output_params(output_params, indent_level=4, max_line_length=115):
    """Format a list of OutputParam objects into a readable string representation.

    Args:
        output_params: List of OutputParam objects to format
        indent_level: Number of spaces to indent each parameter line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)

    Returns:
        A formatted string representing all output parameters
    """
    return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
    """Format a list of ComponentSpec objects into a readable string representation.

    Args:
        components: List of ComponentSpec objects to format
        indent_level: Number of spaces to indent each component line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)
        add_empty_lines: Whether to add empty lines between components (default: True)

    Returns:
        A formatted string representing all components
    """
    if not components:
        return ""

    base_indent = " " * indent_level
    component_indent = " " * (indent_level + 4)
    formatted_components = []

    # Add the header
    formatted_components.append(f"{base_indent}Components:")
    if add_empty_lines:
        formatted_components.append("")

    # Add each component with optional empty lines between them
    for i, component in enumerate(components):
        # Get type name, handling special cases
        type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)

        component_desc = f"{component_indent}{component.name} (`{type_name}`)"
        if component.description:
            component_desc += f": {component.description}"

        # Get the loading fields dynamically
        loading_field_values = []
        for field_name in component.loading_fields():
            field_value = getattr(component, field_name)
            if field_value is not None:
                loading_field_values.append(f"{field_name}={field_value}")

        # Add loading field information if available
        if loading_field_values:
            component_desc += f" [{', '.join(loading_field_values)}]"

        formatted_components.append(component_desc)

        # Add an empty line after each component except the last one
        if add_empty_lines and i < len(components) - 1:
            formatted_components.append("")

    return "\n".join(formatted_components)


def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
    """Format a list of ConfigSpec objects into a readable string representation.

    Args:
        configs: List of ConfigSpec objects to format
        indent_level: Number of spaces to indent each config line (default: 4)
        max_line_length: Maximum length for each line before wrapping (default: 115)
        add_empty_lines: Whether to add empty lines between configs (default: True)

    Returns:
        A formatted string representing all configs
    """
    if not configs:
        return ""

    base_indent = " " * indent_level
    config_indent = " " * (indent_level + 4)
    formatted_configs = []

    # Add the header
    formatted_configs.append(f"{base_indent}Configs:")
    if add_empty_lines:
        formatted_configs.append("")

    # Add each config with optional empty lines between them
    for i, config in enumerate(configs):
        config_desc = f"{config_indent}{config.name} (default: {config.default})"
        if config.description:
            config_desc += f": {config.description}"
        formatted_configs.append(config_desc)

        # Add an empty line after each config except the last one
        if add_empty_lines and i < len(configs) - 1:
            formatted_configs.append("")

    return "\n".join(formatted_configs)
def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None):
    """
    Generates a formatted documentation string describing the pipeline block's parameters and structure.

    Args:
        inputs: List of input parameters
        intermediates_inputs: List of intermediate input parameters
        outputs: List of output parameters
        description (str, *optional*): Description of the block
        class_name (str, *optional*): Name of the class to include in the documentation
        expected_components (List[ComponentSpec], *optional*): List of expected components
        expected_configs (List[ConfigSpec], *optional*): List of expected configurations

    Returns:
        str: A formatted string containing information about components, configs, call parameters,
            intermediate inputs/outputs, and final outputs.
    """
    output = ""

    # Add class name if provided
    if class_name:
        output += f"class {class_name}\n\n"

    # Add description
    if description:
        desc_lines = description.strip().split('\n')
        aligned_desc = '\n'.join('  ' + line for line in desc_lines)
        output += aligned_desc + "\n\n"

    # Add components section if provided
    if expected_components and len(expected_components) > 0:
        components_str = format_components(expected_components, indent_level=2)
        output += components_str + "\n\n"

    # Add configs section if provided
    if expected_configs and len(expected_configs) > 0:
        configs_str = format_configs(expected_configs, indent_level=2)
        output += configs_str + "\n\n"

    # Add inputs section
    output += format_input_params(inputs + intermediates_inputs, indent_level=2)

    # Add outputs section
    output += "\n\n"
    output += format_output_params(outputs, indent_level=2)

    return output
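Putting the formatters together, a minimal doc generated for a toy block (module path assumed):

```python
from diffusers.pipelines.modular_pipeline_utils import (  # assumed module path
    InputParam, OutputParam, make_doc_string,
)

doc = make_doc_string(
    inputs=[InputParam("prompt", required=True, description="text prompt to encode")],
    intermediates_inputs=[InputParam("latents")],
    outputs=[OutputParam("prompt_embeds", description="encoded prompt embeddings")],
    description="Toy text-encoder step.",
    class_name="ToyTextEncoderStep",
)
print(doc)  # class header, description, then Inputs:/Outputs: sections
```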
@@ -333,6 +333,20 @@ def maybe_raise_or_warn(
         )


+# a simpler version of get_class_obj_and_candidates, it won't work with custom code
+def simple_get_class_obj(library_name, class_name):
+    from diffusers import pipelines
+    is_pipeline_module = hasattr(pipelines, library_name)
+
+    if is_pipeline_module:
+        pipeline_module = getattr(pipelines, library_name)
+        class_obj = getattr(pipeline_module, class_name)
+    else:
+        library = importlib.import_module(library_name)
+        class_obj = getattr(library, class_name)
+
+    return class_obj
+
+
 def get_class_obj_and_candidates(
     library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
 ):
@@ -414,7 +428,7 @@ def _get_pipeline_class(
         revision=revision,
     )

-    if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
+    if class_obj.__name__ != "DiffusionPipeline":
         return class_obj

     diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
@@ -841,7 +855,10 @@ def _fetch_class_library_tuple(module):
     library = not_compiled_module.__module__

     # retrieve class_name
-    class_name = not_compiled_module.__class__.__name__
+    if isinstance(not_compiled_module, type):
+        class_name = not_compiled_module.__name__
+    else:
+        class_name = not_compiled_module.__class__.__name__

     return (library, class_name)

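The `_fetch_class_library_tuple` change matters because the registry may now hold classes (for example, a `ComponentSpec.type_hint`) as well as instances. In miniature (`Dummy` and `class_name_of` are hypothetical names for illustration):

```python
class Dummy:
    pass

def class_name_of(obj):
    # New behavior: handle both a class and an instance of it.
    return obj.__name__ if isinstance(obj, type) else obj.__class__.__name__

assert class_name_of(Dummy) == "Dummy"    # a type: old code would have returned "type"
assert class_name_of(Dummy()) == "Dummy"  # an instance: unchanged behavior
```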
@@ -1917,9 +1917,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
         }

+        optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else []
         missing_modules = (
             set(expected_modules)
-            - set(pipeline._optional_components)
+            - set(optional_components)
             - set(pipeline_kwargs.keys())
             - set(true_optional_modules)
         )

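The guard above in miniature: treat a missing or falsy `_optional_components` as an empty list so the set difference never trips over `None` (`P` is a hypothetical pipeline-like stand-in):

```python
class P:  # hypothetical pipeline-like object
    _optional_components = None

pipeline = P()
optional_components = (
    pipeline._optional_components
    if hasattr(pipeline, "_optional_components") and pipeline._optional_components
    else []
)
assert optional_components == []  # no TypeError when building the missing-modules set
```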
@@ -34,7 +34,7 @@ else:
         "StableDiffusionXLDecodeLatentsStep",
         "StableDiffusionXLDenoiseStep",
         "StableDiffusionXLInputStep",
-        "StableDiffusionXLModularPipeline",
+        "StableDiffusionXLModularLoader",
         "StableDiffusionXLPrepareAdditionalConditioningStep",
         "StableDiffusionXLPrepareLatentsStep",
         "StableDiffusionXLSetTimestepsStep",
@@ -65,7 +65,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLDecodeLatentsStep,
         StableDiffusionXLDenoiseStep,
         StableDiffusionXLInputStep,
-        StableDiffusionXLModularPipeline,
+        StableDiffusionXLModularLoader,
         StableDiffusionXLPrepareAdditionalConditioningStep,
         StableDiffusionXLPrepareLatentsStep,
         StableDiffusionXLSetTimestepsStep,

@@ -34,7 +34,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module
 from ..controlnet.multicontrolnet import MultiControlNetModel
 from ..modular_pipeline import (
     AutoPipelineBlocks,
-    ModularPipeline,
+    ModularLoader,
     PipelineBlock,
     PipelineState,
     InputParam,
@@ -56,8 +56,9 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

-from ...schedulers import KarrasDiffusionSchedulers
-from ...guiders import GuiderType, ClassifierFreeGuidance
+from ...schedulers import EulerDiscreteScheduler
+from ...guiders import ClassifierFreeGuidance
+from ...configuration_utils import FrozenDict

 import numpy as np
@@ -182,9 +183,13 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin):
     def expected_components(self) -> List[ComponentSpec]:
         return [
             ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
-            ComponentSpec("feature_extractor", CLIPImageProcessor),
+            ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
             ComponentSpec("unet", UNet2DConditionModel),
-            ComponentSpec("guider", GuiderType),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 7.5}),
+                default_creation_method="from_config"),
         ]

     @property
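This hunk is the template for the rest of the file: prebuilt `obj=...` instances and loose hints like `GuiderType`/`KarrasDiffusionSchedulers` give way to concrete classes plus a frozen default config. A hedged usage sketch of what such a spec enables (import paths assumed from this diff):

```python
# Hedged sketch: materializing a from_config ComponentSpec like the one above.
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance           # assumed path
from diffusers.pipelines.modular_pipeline_utils import ComponentSpec  # assumed path

guider_spec = ComponentSpec(
    "guider",
    ClassifierFreeGuidance,
    config=FrozenDict({"guidance_scale": 7.5}),
    default_creation_method="from_config",
)
guider = guider_spec.create()  # built locally from config; no hub download
# guider._diffusers_load_id == "null", so ComponentsManager treats it as a
# non-pretrained component and skips duplicate-load_id warnings.
```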
@@ -320,7 +325,11 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
|
||||
ComponentSpec("tokenizer", CLIPTokenizer),
|
||||
ComponentSpec("tokenizer_2", CLIPTokenizer),
|
||||
ComponentSpec("guider", GuiderType),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -645,7 +654,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -740,8 +753,16 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()),
            ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config"),
            ComponentSpec(
                "mask_processor",
                VaeImageProcessor,
                config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
                default_creation_method="from_config"),
        ]


@@ -1028,7 +1049,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
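This and the following hunks narrow the scheduler spec from the broad KarrasDiffusionSchedulers union to the concrete EulerDiscreteScheduler, SDXL's default. The spec is still just a type hint, so a scheduler checkpoint loads as usual; for example:

# Loading SDXL's default scheduler explicitly; repo id shown as an example.
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)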
@@ -1151,7 +1172,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
@@ -1206,7 +1227,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
@@ -1460,7 +1481,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
@@ -1608,7 +1629,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

    @property
@@ -1727,7 +1748,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("requires_aesthetics_score", default=False),]
        return [ConfigSpec("requires_aesthetics_score", False),]

    @property
    def description(self) -> str:
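The last hunk drops the default= keyword from ConfigSpec in favor of a positional second argument. A minimal sketch of a spec shaped like that call site (the real class lives in modular_pipeline_utils and may carry more fields):

# Sketch of a two-field config spec matching the new positional call:
# ConfigSpec("requires_aesthetics_score", False)
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class ConfigSpecSketch:
    name: str
    default: Any = None

spec = ConfigSpecSketch("requires_aesthetics_score", False)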
@@ -2062,8 +2083,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config"),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
        ]

@@ -2245,7 +2270,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
            ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
        )

        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
        with self.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                guider_data = pipeline.guider.prepare_inputs(data)
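The denoise loops now call self.progress_bar(...) on the block rather than pipeline.progress_bar(...), so a block can report progress without the full pipeline in scope. A minimal sketch of such a helper, assuming the usual tqdm-backed pattern (tqdm instances are context managers, so the with-statement above works unchanged):

# Sketch: a tqdm-backed progress_bar helper on the block itself.
from tqdm.auto import tqdm

class BlockProgressMixin:
    def progress_bar(self, total=None):
        # tqdm objects support the context-manager protocol, matching
        # `with self.progress_bar(total=...) as progress_bar:`
        return tqdm(total=total)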
@@ -2316,11 +2341,15 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config"),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec("controlnet", ControlNetModel),
            ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
            ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"),
        ]

    @property
@@ -2626,7 +2655,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
        )

        # (5) Denoise loop
        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
        with self.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                guider_data = pipeline.guider.prepare_inputs(data)
@@ -2733,9 +2762,17 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
        return [
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec("controlnet", ControlNetUnionModel),
            ComponentSpec("scheduler", KarrasDiffusionSchedulers),
            ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()),
            ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config"),
            ComponentSpec(
                "control_image_processor",
                VaeImageProcessor,
                config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
                default_creation_method="from_config"),
        ]

    @property
@@ -3029,7 +3066,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
            ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"),
        )

        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
        with self.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                guider_data = pipeline.guider.prepare_inputs(data)
@@ -3136,7 +3173,11 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor())
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config"),
        ]

    @property
@@ -3527,9 +3568,14 @@ SDXL_SUPPORTED_BLOCKS = {
}


# YiYi TODO: rename to components etc. and not inherit from ModularPipeline
class StableDiffusionXLModularPipeline(
    ModularPipeline,
# YiYi Notes: model specific components:
## (1) it should inherit from ModularLoader
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularLoader(
    ModularLoader,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
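Per the notes above, the model-specific loader is a thin composition: the generic ModularLoader container plus SDXL's existing mixins, with component-derived defaults exposed as properties. A hedged sketch of point (3), using the standard SDXL relation between VAE block count and scale factor (the class name and fallback value are assumptions for illustration):

# Sketch of a component-derived default config on the loader (point 3).
class SDXLLoaderDefaultsSketch:
    vae = None  # registered when components are loaded

    @property
    def vae_scale_factor(self) -> int:
        # SDXL convention: scale factor 2^(n_blocks - 1); fall back to 8
        # when no VAE has been registered yet.
        if self.vae is not None:
            return 2 ** (len(self.vae.config.block_out_channels) - 1)
        return 8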
@@ -1328,7 +1328,7 @@ class LDMSuperResolutionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class ModularPipeline(metaclass=DummyObject):
class ModularLoader(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

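These stubs follow diffusers' dummy-object pattern: when a backend isn't installed, the placeholder class raises a helpful error at use time instead of breaking imports. A simplified sketch of the idea (the real DummyObject and requires_backends live in diffusers.utils and differ in detail):

# Simplified sketch of the dummy-object pattern used by these stubs.
class DummyObjectSketch(type):
    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the following backends: {cls._backends}. "
            "Install them to use this class."
        )

class ModularLoaderStub(metaclass=DummyObjectSketch):
    _backends = ["torch"]

# ModularLoaderStub()  # would raise ImportError: requires ['torch']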
@@ -2417,7 +2417,7 @@ class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLModularPipeline(metaclass=DummyObject):
class StableDiffusionXLModularLoader(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):