1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
yiyixuxu
2025-06-28 12:50:11 +02:00
parent 58dbe0c29e
commit 49ea4d1bf5
5 changed files with 81 additions and 53 deletions

View File

@@ -232,8 +232,18 @@ class AutoOffloadStrategy:
class ComponentsManager:
_available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"]
_available_info_fields = [
"model_id",
"added_time",
"collection",
"class_name",
"size_gb",
"adapters",
"has_hook",
"execution_device",
"ip_adapter",
]
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
@@ -241,10 +251,16 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None):
def _lookup_ids(
self,
name: Optional[str] = None,
collection: Optional[str] = None,
load_id: Optional[str] = None,
components: Optional[OrderedDict] = None,
):
"""
Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
Returns a set of component_ids
Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
component_ids
"""
if components is None:
components = self.components
@@ -318,10 +334,14 @@ class ComponentsManager:
if component_id not in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}")
logger.warning(
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
)
self.remove(comp_id)
self.collections[collection].add(component_id)
logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}")
logger.info(
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
)
else:
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
@@ -379,40 +399,43 @@ class ComponentsManager:
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
collection: Optional collection to filter by
load_id: Optional load_id to filter by
return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
If False, returns a dictionary with component IDs as keys
return_dict_with_names:
If True, returns a dictionary with component names as keys, throw an error if
multiple components with the same name are found If False, returns a dictionary
with component IDs as keys
Returns:
Dictionary mapping component names to components if return_dict_with_names=True,
or a dictionary mapping component IDs to components if return_dict_with_names=False
Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
component IDs to components if return_dict_with_names=False
"""
# select components based on collection and load_id filters
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
def get_return_dict(components, return_dict_with_names):
"""
Create a dictionary mapping component names to components if return_dict_with_names=True,
or a dictionary mapping component IDs to components if return_dict_with_names=False,
throw an error if duplicate component names are found when return_dict_with_names=True
Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
names are found when return_dict_with_names=True
"""
if return_dict_with_names:
dict_to_return = {}
for comp_id, comp in components.items():
comp_name = self._id_to_name(comp_id)
if comp_name in dict_to_return:
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
raise ValueError(
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
)
dict_to_return[comp_name] = comp
return dict_to_return
else:
return components
# if no names are provided, return the filtered components as it is
if names is None:
return get_return_dict(components, return_dict_with_names)
# if names is not a string, raise an error
elif not isinstance(names, str):
raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
@@ -488,9 +511,7 @@ class ComponentsManager:
}
if is_not_pattern:
logger.info(
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
)
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
@@ -584,8 +605,8 @@ class ComponentsManager:
# YiYi TODO: (1) add quantization info
def get_model_info(
self,
component_id: str,
self,
component_id: str,
fields: Optional[Union[str, List[str]]] = None,
) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
@@ -603,7 +624,7 @@ class ComponentsManager:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[component_id]
# Validate fields if specified
if fields is not None:
if isinstance(fields, str):
@@ -662,7 +683,7 @@ class ComponentsManager:
return {k: v for k, v in info.items() if k in fields}
else:
return info
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
def __repr__(self):
# Handle empty components case
@@ -820,11 +841,9 @@ class ComponentsManager:
load_id: Optional[str] = None,
) -> Any:
"""
Get a single component by either:
(1) searching name (pattern matching), collection, or load_id.
(2) passing in a component_id
Raises an error if multiple components match or none are found.
support pattern matching for name
Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
a component_id Raises an error if multiple components match or none are found. support pattern matching for
name
Args:
component_id: Optional component ID to get
@@ -841,7 +860,7 @@ class ComponentsManager:
if component_id is not None and (name is not None or collection is not None or load_id is not None):
raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
# search by component_id
if component_id is not None:
if component_id not in self.components:
@@ -857,7 +876,6 @@ class ComponentsManager:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
"""
@@ -869,7 +887,7 @@ class ComponentsManager:
for name in names:
ids.update(self._lookup_ids(name=name, collection=collection))
return list(ids)
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
"""
Get components by a list of IDs.
@@ -881,7 +899,9 @@ class ComponentsManager:
for comp_id, comp in components.items():
comp_name = self._id_to_name(comp_id)
if comp_name in dict_to_return:
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
raise ValueError(
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
)
dict_to_return[comp_name] = comp
return dict_to_return
else:
@@ -894,6 +914,7 @@ class ComponentsManager:
ids = self.get_ids(names, collection)
return self.get_components_by_ids(ids)
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.

View File

@@ -1849,23 +1849,30 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
return module.dtype
return torch.float32
@property
def null_component_names(self) -> List[str]:
return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
@property
def component_names(self) -> List[str]:
return list(self.components.keys())
@property
def pretrained_component_names(self) -> List[str]:
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"]
return [
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
]
@property
def config_component_names(self) -> List[str]:
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"]
return [
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_config"
]
@property
def components(self) -> Dict[str, Any]:
@@ -2430,9 +2437,13 @@ class ModularPipeline:
raise ValueError(f"Output '{output}' is not a valid output type")
def load_default_components(self, **kwargs):
names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"]
names = [
name
for name in self.loader._component_specs.keys()
if self.loader._component_specs[name].default_creation_method == "from_pretrained"
]
self.loader.load(names=names, **kwargs)
def load_components(self, names: Union[List[str], str], **kwargs):
self.loader.load(names=names, **kwargs)

View File

@@ -23,18 +23,18 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
_import_structure["modular_blocks_presets"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"IP_ADAPTER_BLOCKS",
"ALL_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLAutoControlnetStep",
"StableDiffusionXLAutoDecodeStep",
"StableDiffusionXLAutoIPAdapterStep",
"StableDiffusionXLAutoVaeEncoderStep",
"StableDiffusionXLAutoControlnetStep",
]
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
@@ -49,18 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLTextEncoderStep,
)
from .modular_blocks_presets import (
ALL_BLOCKS,
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
ALL_BLOCKS,
TEXT2IMAGE_BLOCKS,
StableDiffusionXLAutoBlocks,
StableDiffusionXLAutoControlnetStep,
StableDiffusionXLAutoDecodeStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLAutoControlnetStep,
)
from .modular_loader import StableDiffusionXLModularLoader
else:

View File

@@ -76,9 +76,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
@property
def description(self):
return (
"Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
)
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
# before_denoise: text2img

View File

@@ -44,7 +44,6 @@ class StableDiffusionXLModularLoader(
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@@ -52,8 +51,7 @@ class StableDiffusionXLModularLoader(
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
default_sample_size = 128