diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 3394c67cb0..f2a2b0b080 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -232,8 +232,18 @@ class AutoOffloadStrategy: class ComponentsManager: - _available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"] - + _available_info_fields = [ + "model_id", + "added_time", + "collection", + "class_name", + "size_gb", + "adapters", + "has_hook", + "execution_device", + "ip_adapter", + ] + def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added @@ -241,10 +251,16 @@ class ComponentsManager: self.model_hooks = None self._auto_offload_enabled = False - def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None): + def _lookup_ids( + self, + name: Optional[str] = None, + collection: Optional[str] = None, + load_id: Optional[str] = None, + components: Optional[OrderedDict] = None, + ): """ - Lookup component_ids by name, collection, or load_id. Does not support pattern matching. - Returns a set of component_ids + Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of + component_ids """ if components is None: components = self.components @@ -318,10 +334,14 @@ class ComponentsManager: if component_id not in self.collections[collection]: comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) for comp_id in comp_ids_in_collection: - logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}") + logger.warning( + f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}" + ) self.remove(comp_id) self.collections[collection].add(component_id) - logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}") + logger.info( + f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}" + ) else: logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") @@ -379,40 +399,43 @@ class ComponentsManager: - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" collection: Optional collection to filter by load_id: Optional load_id to filter by - return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found - If False, returns a dictionary with component IDs as keys + return_dict_with_names: + If True, returns a dictionary with component names as keys, throw an error if + multiple components with the same name are found If False, returns a dictionary + with component IDs as keys Returns: - Dictionary mapping component names to components if return_dict_with_names=True, - or a dictionary mapping component IDs to components if return_dict_with_names=False + Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping + component IDs to components if return_dict_with_names=False """ # select components based on collection and load_id filters selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} - + def get_return_dict(components, return_dict_with_names): """ - Create a dictionary mapping component names to components if return_dict_with_names=True, - or a dictionary mapping component IDs to components if return_dict_with_names=False, - throw an error if duplicate component names are found when return_dict_with_names=True + Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary + mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component + names are found when return_dict_with_names=True """ if return_dict_with_names: dict_to_return = {} for comp_id, comp in components.items(): comp_name = self._id_to_name(comp_id) if comp_name in dict_to_return: - raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) dict_to_return[comp_name] = comp return dict_to_return else: return components - # if no names are provided, return the filtered components as it is if names is None: return get_return_dict(components, return_dict_with_names) - + # if names is not a string, raise an error elif not isinstance(names, str): raise ValueError(f"Invalid type for `names: {type(names)}, only support string") @@ -488,9 +511,7 @@ class ComponentsManager: } if is_not_pattern: - logger.info( - f"Getting all components except those with base name '{names}': {list(matches.keys())}" - ) + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") @@ -584,8 +605,8 @@ class ComponentsManager: # YiYi TODO: (1) add quantization info def get_model_info( - self, - component_id: str, + self, + component_id: str, fields: Optional[Union[str, List[str]]] = None, ) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -603,7 +624,7 @@ class ComponentsManager: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] - + # Validate fields if specified if fields is not None: if isinstance(fields, str): @@ -662,7 +683,7 @@ class ComponentsManager: return {k: v for k, v in info.items() if k in fields} else: return info - + # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table def __repr__(self): # Handle empty components case @@ -820,11 +841,9 @@ class ComponentsManager: load_id: Optional[str] = None, ) -> Any: """ - Get a single component by either: - (1) searching name (pattern matching), collection, or load_id. - (2) passing in a component_id - Raises an error if multiple components match or none are found. - support pattern matching for name + Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in + a component_id Raises an error if multiple components match or none are found. support pattern matching for + name Args: component_id: Optional component ID to get @@ -841,7 +860,7 @@ class ComponentsManager: if component_id is not None and (name is not None or collection is not None or load_id is not None): raise ValueError("If searching by component_id, do not pass name, collection, or load_id") - + # search by component_id if component_id is not None: if component_id not in self.components: @@ -857,7 +876,6 @@ class ComponentsManager: raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") return next(iter(results.values())) - def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None): """ @@ -869,7 +887,7 @@ class ComponentsManager: for name in names: ids.update(self._lookup_ids(name=name, collection=collection)) return list(ids) - + def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True): """ Get components by a list of IDs. @@ -881,7 +899,9 @@ class ComponentsManager: for comp_id, comp in components.items(): comp_name = self._id_to_name(comp_id) if comp_name in dict_to_return: - raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys") + raise ValueError( + f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" + ) dict_to_return[comp_name] = comp return dict_to_return else: @@ -894,6 +914,7 @@ class ComponentsManager: ids = self.get_ids(names, collection) return self.get_components_by_ids(ids) + def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 7bb3936339..6bdd2f3f36 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1849,23 +1849,30 @@ class ModularLoader(ConfigMixin, PushToHubMixin): return module.dtype return torch.float32 - + @property def null_component_names(self) -> List[str]: return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None] - + @property def component_names(self) -> List[str]: return list(self.components.keys()) - + @property def pretrained_component_names(self) -> List[str]: - return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] - + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + @property def config_component_names(self) -> List[str]: - return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"] - + return [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_config" + ] @property def components(self) -> Dict[str, Any]: @@ -2430,9 +2437,13 @@ class ModularPipeline: raise ValueError(f"Output '{output}' is not a valid output type") def load_default_components(self, **kwargs): - names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"] + names = [ + name + for name in self.loader._component_specs.keys() + if self.loader._component_specs[name].default_creation_method == "from_pretrained" + ] self.loader.load(names=names, **kwargs) - + def load_components(self, names: Union[List[str], str], **kwargs): self.loader.load(names=names, **kwargs) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 9adb052795..95461cfc23 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -23,18 +23,18 @@ except OptionalDependencyNotAvailable: else: _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] _import_structure["modular_blocks_presets"] = [ + "ALL_BLOCKS", "AUTO_BLOCKS", "CONTROLNET_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "IP_ADAPTER_BLOCKS", - "ALL_BLOCKS", "TEXT2IMAGE_BLOCKS", "StableDiffusionXLAutoBlocks", + "StableDiffusionXLAutoControlnetStep", "StableDiffusionXLAutoDecodeStep", "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", - "StableDiffusionXLAutoControlnetStep", ] _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] @@ -49,18 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLTextEncoderStep, ) from .modular_blocks_presets import ( + ALL_BLOCKS, AUTO_BLOCKS, CONTROLNET_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, IP_ADAPTER_BLOCKS, - ALL_BLOCKS, TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoBlocks, + StableDiffusionXLAutoControlnetStep, StableDiffusionXLAutoDecodeStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLAutoControlnetStep, ) from .modular_loader import StableDiffusionXLModularLoader else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py index d28eb5387a..fee955411c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py @@ -76,9 +76,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): @property def description(self): - return ( - "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" - ) + return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" # before_denoise: text2img diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 34222444da..c161c9290f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -44,7 +44,6 @@ class StableDiffusionXLModularLoader( StableDiffusionXLLoraLoaderMixin, ModularIPAdapterMixin, ): - @property def default_height(self): return self.default_sample_size * self.vae_scale_factor @@ -52,8 +51,7 @@ class StableDiffusionXLModularLoader( @property def default_width(self): return self.default_sample_size * self.vae_scale_factor - - + @property def default_sample_size(self): default_sample_size = 128