mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
style
This commit is contained in:
@@ -232,8 +232,18 @@ class AutoOffloadStrategy:
|
||||
|
||||
|
||||
class ComponentsManager:
|
||||
_available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"]
|
||||
|
||||
_available_info_fields = [
|
||||
"model_id",
|
||||
"added_time",
|
||||
"collection",
|
||||
"class_name",
|
||||
"size_gb",
|
||||
"adapters",
|
||||
"has_hook",
|
||||
"execution_device",
|
||||
"ip_adapter",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
self.added_time = OrderedDict() # Store when components were added
|
||||
@@ -241,10 +251,16 @@ class ComponentsManager:
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None):
|
||||
def _lookup_ids(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
load_id: Optional[str] = None,
|
||||
components: Optional[OrderedDict] = None,
|
||||
):
|
||||
"""
|
||||
Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
|
||||
Returns a set of component_ids
|
||||
Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
|
||||
component_ids
|
||||
"""
|
||||
if components is None:
|
||||
components = self.components
|
||||
@@ -318,10 +334,14 @@ class ComponentsManager:
|
||||
if component_id not in self.collections[collection]:
|
||||
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
|
||||
for comp_id in comp_ids_in_collection:
|
||||
logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}")
|
||||
logger.warning(
|
||||
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
|
||||
)
|
||||
self.remove(comp_id)
|
||||
self.collections[collection].add(component_id)
|
||||
logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}")
|
||||
logger.info(
|
||||
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
|
||||
|
||||
@@ -379,40 +399,43 @@ class ComponentsManager:
|
||||
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
|
||||
collection: Optional collection to filter by
|
||||
load_id: Optional load_id to filter by
|
||||
return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
|
||||
If False, returns a dictionary with component IDs as keys
|
||||
return_dict_with_names:
|
||||
If True, returns a dictionary with component names as keys, throw an error if
|
||||
multiple components with the same name are found If False, returns a dictionary
|
||||
with component IDs as keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping component names to components if return_dict_with_names=True,
|
||||
or a dictionary mapping component IDs to components if return_dict_with_names=False
|
||||
Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
|
||||
component IDs to components if return_dict_with_names=False
|
||||
"""
|
||||
|
||||
# select components based on collection and load_id filters
|
||||
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
|
||||
components = {k: self.components[k] for k in selected_ids}
|
||||
|
||||
|
||||
def get_return_dict(components, return_dict_with_names):
|
||||
"""
|
||||
Create a dictionary mapping component names to components if return_dict_with_names=True,
|
||||
or a dictionary mapping component IDs to components if return_dict_with_names=False,
|
||||
throw an error if duplicate component names are found when return_dict_with_names=True
|
||||
Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
|
||||
mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
|
||||
names are found when return_dict_with_names=True
|
||||
"""
|
||||
if return_dict_with_names:
|
||||
dict_to_return = {}
|
||||
for comp_id, comp in components.items():
|
||||
comp_name = self._id_to_name(comp_id)
|
||||
if comp_name in dict_to_return:
|
||||
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
|
||||
raise ValueError(
|
||||
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
|
||||
)
|
||||
dict_to_return[comp_name] = comp
|
||||
return dict_to_return
|
||||
else:
|
||||
return components
|
||||
|
||||
|
||||
# if no names are provided, return the filtered components as it is
|
||||
if names is None:
|
||||
return get_return_dict(components, return_dict_with_names)
|
||||
|
||||
|
||||
# if names is not a string, raise an error
|
||||
elif not isinstance(names, str):
|
||||
raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
|
||||
@@ -488,9 +511,7 @@ class ComponentsManager:
|
||||
}
|
||||
|
||||
if is_not_pattern:
|
||||
logger.info(
|
||||
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
|
||||
)
|
||||
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||
|
||||
@@ -584,8 +605,8 @@ class ComponentsManager:
|
||||
|
||||
# YiYi TODO: (1) add quantization info
|
||||
def get_model_info(
|
||||
self,
|
||||
component_id: str,
|
||||
self,
|
||||
component_id: str,
|
||||
fields: Optional[Union[str, List[str]]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get comprehensive information about a component.
|
||||
@@ -603,7 +624,7 @@ class ComponentsManager:
|
||||
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
|
||||
|
||||
component = self.components[component_id]
|
||||
|
||||
|
||||
# Validate fields if specified
|
||||
if fields is not None:
|
||||
if isinstance(fields, str):
|
||||
@@ -662,7 +683,7 @@ class ComponentsManager:
|
||||
return {k: v for k, v in info.items() if k in fields}
|
||||
else:
|
||||
return info
|
||||
|
||||
|
||||
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
|
||||
def __repr__(self):
|
||||
# Handle empty components case
|
||||
@@ -820,11 +841,9 @@ class ComponentsManager:
|
||||
load_id: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get a single component by either:
|
||||
(1) searching name (pattern matching), collection, or load_id.
|
||||
(2) passing in a component_id
|
||||
Raises an error if multiple components match or none are found.
|
||||
support pattern matching for name
|
||||
Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
|
||||
a component_id Raises an error if multiple components match or none are found. support pattern matching for
|
||||
name
|
||||
|
||||
Args:
|
||||
component_id: Optional component ID to get
|
||||
@@ -841,7 +860,7 @@ class ComponentsManager:
|
||||
|
||||
if component_id is not None and (name is not None or collection is not None or load_id is not None):
|
||||
raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
|
||||
|
||||
|
||||
# search by component_id
|
||||
if component_id is not None:
|
||||
if component_id not in self.components:
|
||||
@@ -857,7 +876,6 @@ class ComponentsManager:
|
||||
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
|
||||
|
||||
return next(iter(results.values()))
|
||||
|
||||
|
||||
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
@@ -869,7 +887,7 @@ class ComponentsManager:
|
||||
for name in names:
|
||||
ids.update(self._lookup_ids(name=name, collection=collection))
|
||||
return list(ids)
|
||||
|
||||
|
||||
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
|
||||
"""
|
||||
Get components by a list of IDs.
|
||||
@@ -881,7 +899,9 @@ class ComponentsManager:
|
||||
for comp_id, comp in components.items():
|
||||
comp_name = self._id_to_name(comp_id)
|
||||
if comp_name in dict_to_return:
|
||||
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
|
||||
raise ValueError(
|
||||
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
|
||||
)
|
||||
dict_to_return[comp_name] = comp
|
||||
return dict_to_return
|
||||
else:
|
||||
@@ -894,6 +914,7 @@ class ComponentsManager:
|
||||
ids = self.get_ids(names, collection)
|
||||
return self.get_components_by_ids(ids)
|
||||
|
||||
|
||||
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Summarizes a dictionary by finding common prefixes that share the same value.
|
||||
|
||||
|
||||
@@ -1849,23 +1849,30 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
return module.dtype
|
||||
|
||||
return torch.float32
|
||||
|
||||
|
||||
@property
|
||||
def null_component_names(self) -> List[str]:
|
||||
return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
|
||||
|
||||
|
||||
@property
|
||||
def component_names(self) -> List[str]:
|
||||
return list(self.components.keys())
|
||||
|
||||
|
||||
@property
|
||||
def pretrained_component_names(self) -> List[str]:
|
||||
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"]
|
||||
|
||||
return [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
|
||||
@property
|
||||
def config_component_names(self) -> List[str]:
|
||||
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"]
|
||||
|
||||
return [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_config"
|
||||
]
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
@@ -2430,9 +2437,13 @@ class ModularPipeline:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"]
|
||||
names = [
|
||||
name
|
||||
for name in self.loader._component_specs.keys()
|
||||
if self.loader._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
self.loader.load(names=names, **kwargs)
|
||||
|
||||
|
||||
def load_components(self, names: Union[List[str], str], **kwargs):
|
||||
self.loader.load(names=names, **kwargs)
|
||||
|
||||
|
||||
@@ -23,18 +23,18 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
|
||||
_import_structure["modular_blocks_presets"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"CONTROLNET_BLOCKS",
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"IP_ADAPTER_BLOCKS",
|
||||
"ALL_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLAutoControlnetStep",
|
||||
"StableDiffusionXLAutoDecodeStep",
|
||||
"StableDiffusionXLAutoIPAdapterStep",
|
||||
"StableDiffusionXLAutoVaeEncoderStep",
|
||||
"StableDiffusionXLAutoControlnetStep",
|
||||
]
|
||||
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
|
||||
|
||||
@@ -49,18 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLTextEncoderStep,
|
||||
)
|
||||
from .modular_blocks_presets import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
CONTROLNET_BLOCKS,
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
IP_ADAPTER_BLOCKS,
|
||||
ALL_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLAutoControlnetStep,
|
||||
StableDiffusionXLAutoDecodeStep,
|
||||
StableDiffusionXLAutoIPAdapterStep,
|
||||
StableDiffusionXLAutoVaeEncoderStep,
|
||||
StableDiffusionXLAutoControlnetStep,
|
||||
)
|
||||
from .modular_loader import StableDiffusionXLModularLoader
|
||||
else:
|
||||
|
||||
@@ -76,9 +76,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
|
||||
)
|
||||
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
|
||||
|
||||
|
||||
# before_denoise: text2img
|
||||
|
||||
@@ -44,7 +44,6 @@ class StableDiffusionXLModularLoader(
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
ModularIPAdapterMixin,
|
||||
):
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
@@ -52,8 +51,7 @@ class StableDiffusionXLModularLoader(
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
default_sample_size = 128
|
||||
|
||||
Reference in New Issue
Block a user