From 7d2a633e02724ded6960d9cd4e5515c366e82698 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 25 Jun 2025 11:26:36 +0200 Subject: [PATCH] style --- src/diffusers/__init__.py | 24 +- src/diffusers/commands/custom_blocks.py | 11 +- src/diffusers/commands/diffusers_cli.py | 2 +- .../guiders/adaptive_projected_guidance.py | 23 +- src/diffusers/guiders/auto_guidance.py | 21 +- .../guiders/classifier_free_guidance.py | 15 +- .../classifier_free_zero_star_guidance.py | 17 +- src/diffusers/guiders/guider_utils.py | 16 +- src/diffusers/guiders/skip_layer_guidance.py | 25 +- .../guiders/smoothed_energy_guidance.py | 25 +- .../tangential_classifier_free_guidance.py | 19 +- src/diffusers/hooks/layer_skip.py | 15 +- .../hooks/smoothed_energy_guidance_utils.py | 20 +- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/components_manager.py | 175 ++-- .../modular_pipelines/modular_pipeline.py | 335 ++++---- .../modular_pipeline_utils.py | 169 ++-- src/diffusers/modular_pipelines/node_utils.py | 97 +-- .../stable_diffusion_xl/__init__.py | 21 +- .../stable_diffusion_xl/before_denoise.py | 249 +++--- .../stable_diffusion_xl/decoders.py | 28 +- .../stable_diffusion_xl/denoise.py | 782 ++---------------- .../stable_diffusion_xl/encoders.py | 119 ++- .../modular_block_mappings.py | 42 +- .../stable_diffusion_xl/modular_loader.py | 9 +- .../modular_pipeline_presets.py | 11 +- src/diffusers/utils/dynamic_modules_utils.py | 4 +- 28 files changed, 828 insertions(+), 1452 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 164ee216f3..18d90be500 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -794,8 +794,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LayerSkipConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, - apply_layer_skip, apply_faster_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( @@ -875,6 +875,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: WanTransformer3DModel, WanVACETransformer3DModel, ) + from .modular_pipelines import ( + ComponentsManager, + ComponentSpec, + ModularLoader, + ModularPipeline, + ModularPipelineBlocks, + ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -907,13 +914,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ScoreSdeVePipeline, StableDiffusionMixin, ) - from .modular_pipelines import ( - ModularLoader, - ModularPipeline, - ModularPipelineBlocks, - ComponentSpec, - ComponentsManager, - ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, @@ -978,6 +978,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .modular_pipelines import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) from .pipelines import ( AllegroPipeline, AltDiffusionImg2ImgPipeline, @@ -1182,10 +1186,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .modular_pipelines import ( - StableDiffusionXLAutoPipeline, - StableDiffusionXLModularLoader, - ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index d2f2de3a8f..f532e8b775 100644 --- a/src/diffusers/commands/custom_blocks.py +++ 
b/src/diffusers/commands/custom_blocks.py @@ -18,10 +18,11 @@ Usage example: """ import ast -from argparse import ArgumentParser, Namespace -from pathlib import Path import importlib.util import os +from argparse import ArgumentParser, Namespace +from pathlib import Path + from ..utils import logging from . import BaseDiffusersCLICommand @@ -57,7 +58,7 @@ class CustomBlocksCommand(BaseDiffusersCLICommand): # determine the block to be saved. out = self._get_class_names(self.block_module_name) classes_found = list({cls for cls, _ in out}) - + if self.block_class_name is not None: child_class, parent_class = self._choose_block(out, self.block_class_name) if child_class is None and parent_class is None: @@ -125,9 +126,9 @@ class CustomBlocksCommand(BaseDiffusersCLICommand): val = self._get_base_name(node.value) return f"{val}.{node.attr}" if val else node.attr return None - + def _create_automap(self, parent_class, child_class): module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] auto_map = {f"{parent_class}": f"{module}.{child_class}"} return {"auto_map": auto_map} - + diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py index f291303d1e..a27ac24f2a 100644 --- a/src/diffusers/commands/diffusers_cli.py +++ b/src/diffusers/commands/diffusers_cli.py @@ -15,9 +15,9 @@ from argparse import ArgumentParser +from .custom_blocks import CustomBlocksCommand from .env import EnvironmentCommand from .fp16_safetensors import FP16SafetensorsCommand -from .custom_blocks import CustomBlocksCommand def main(): diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index ef2f3f2c84..f1a6096c4d 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. 
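# [Editor's sketch, not part of the patch] What `_create_automap` above produces
# for a hypothetical custom-block file `my_block.py` defining
# `MyBlock(ModularPipelineBlocks)`:
module = "my_block"  # file name with ".py" stripped
auto_map = {"ModularPipelineBlocks": f"{module}.MyBlock"}
# -> stored in the block's config as {"auto_map": {...}} so the class can be
#    re-imported from the Hub file later.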
import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -74,10 +75,10 @@ class AdaptiveProjectedGuidance(BaseGuidance): self.momentum_buffer = None def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) @@ -123,19 +124,19 @@ class AdaptiveProjectedGuidance(BaseGuidance): def _is_apg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close @@ -160,25 +161,25 @@ def normalized_guidance( ): diff = pred_cond - pred_uncond dim = [-i for i in range(1, len(diff.shape))] - + if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average - + if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor - + v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = diff_orthogonal + eta * diff_parallel - + pred = pred_cond if use_original_formulation else pred_uncond pred = pred + guidance_scale * normalized_update - + return pred diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 791cc582ad..83120c20ce 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
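# [Editor's sketch, not part of the patch] The projection at the heart of
# `normalized_guidance` above: the guidance direction is split into components
# parallel and orthogonal to the conditional prediction, and `eta` down-weights
# the parallel part. Shapes and values below are hypothetical.
import torch

pred_cond = torch.randn(2, 4, 64, 64)
pred_uncond = torch.randn(2, 4, 64, 64)
guidance_scale, eta = 7.5, 0.0

diff = pred_cond - pred_uncond
dim = [-i for i in range(1, diff.ndim)]                   # all non-batch dims
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
v0 = diff.double()
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1   # component along v1
update = (v0 - v0_parallel + eta * v0_parallel).type_as(diff)
pred = pred_uncond + guidance_scale * update              # default formulation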
import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -113,18 +114,18 @@ class AutoGuidance(BaseGuidance): if self._is_ag_enabled() and self.is_unconditional: for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): _apply_layer_skip_hook(denoiser, config, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: if self._is_ag_enabled() and self.is_unconditional: for name in self._auto_guidance_hook_names: registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -144,9 +145,9 @@ class AutoGuidance(BaseGuidance): if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - + return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 @@ -161,17 +162,17 @@ class AutoGuidance(BaseGuidance): def _is_ag_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index a459e51cd0..faeba09711 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. 
import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -74,12 +75,12 @@ class ClassifierFreeGuidance(BaseGuidance): self.guidance_scale = guidance_scale self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -116,17 +117,17 @@ class ClassifierFreeGuidance(BaseGuidance): def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index a722f26050..b4dee9295a 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. 
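# [Editor's sketch, not part of the patch] How the repeated `_is_*_enabled`
# gate above maps the (start, stop) fractions onto discrete steps, and why a
# scale close to 1.0 disables standard CFG. Values are hypothetical.
import math

guidance_scale, start, stop, num_inference_steps = 7.5, 0.0, 0.8, 30
for step in (0, 23, 24):
    within = int(start * num_inference_steps) <= step < int(stop * num_inference_steps)
    trivial = math.isclose(guidance_scale, 1.0)  # CFG at scale 1.0 is a no-op
    print(step, within and not trivial)          # 0 True, 23 True, 24 False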
import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -72,12 +73,12 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): self.zero_init_steps = zero_init_steps self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -106,7 +107,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 @@ -121,19 +122,19 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index e8e873f5c8..87109eb048 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -58,10 +58,10 @@ class BaseGuidance: def disable(self): self._enabled = False - + def enable(self): self._enabled = True - + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps @@ -104,14 +104,14 @@ class BaseGuidance: f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." ) self._input_fields = kwargs - + def prepare_models(self, denoiser: torch.nn.Module) -> None: """ Prepares the models for the guidance technique on a given batch of data. This method should be overridden in subclasses to implement specific model preparation logic. """ self._count_prepared += 1 - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: """ Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in @@ -119,7 +119,7 @@ class BaseGuidance: modifications made during `prepare_models`. 
""" pass - + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") @@ -139,15 +139,15 @@ class BaseGuidance: @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") - + @property def is_unconditional(self) -> bool: return not self.is_conditional - + @property def num_conditions(self) -> int: raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") - + @classmethod def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": """ diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 7c19f6391f..ffe00ea7db 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -148,19 +149,19 @@ class SkipLayerGuidance(BaseGuidance): if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module) -> None: if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -204,7 +205,7 @@ class SkipLayerGuidance(BaseGuidance): pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 or self._count_prepared == 3 @@ -221,31 +222,31 @@ class SkipLayerGuidance(BaseGuidance): def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def _is_slg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < 
self._step < skip_stop_step - + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) - + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 3986da913f..ab21b6d952 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -21,6 +21,7 @@ from ..hooks import HookRegistry from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -141,19 +142,19 @@ class SmoothedEnergyGuidance(BaseGuidance): if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) - + def cleanup_models(self, denoiser: torch.nn.Module): if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -197,7 +198,7 @@ class SmoothedEnergyGuidance(BaseGuidance): pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred, {} - + @property def is_conditional(self) -> bool: return self._count_prepared == 1 or self._count_prepared == 3 @@ -214,31 +215,31 @@ class SmoothedEnergyGuidance(BaseGuidance): def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def _is_seg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step - + is_zero = math.isclose(self.seg_guidance_scale, 0.0) - + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 017693fd9f..fdcdaf8dcb 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,12 +13,13 @@ # limitations under the License. 
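# [Editor's sketch, not part of the patch] The `guidance_rescale` branch used
# by the guiders above, assuming `rescale_noise_cfg` in guider_utils follows
# the standard diffusers implementation of the CFG-rescale trick from
# "Common Diffusion Noise Schedules and Sample Steps are Flawed":
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # match the per-sample std of the guided prediction to the conditional one
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # interpolate back toward the unrescaled prediction to avoid overly flat images
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg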
import math -from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from .guider_utils import BaseGuidance, rescale_noise_cfg + if TYPE_CHECKING: from ..modular_pipelines.modular_pipeline import BlockState @@ -63,10 +64,10 @@ class TangentialClassifierFreeGuidance(BaseGuidance): self.use_original_formulation = use_original_formulation def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: - + if input_fields is None: input_fields = self._input_fields - + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): @@ -101,24 +102,24 @@ class TangentialClassifierFreeGuidance(BaseGuidance): def _is_tcfg_enabled(self) -> bool: if not self._enabled: return False - + is_within_range = True if self._num_inference_steps is not None: skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step - + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) - + return is_within_range and not is_close def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: - cond_dtype = pred_cond.dtype + cond_dtype = pred_cond.dtype preds = torch.stack([pred_cond, pred_uncond], dim=1).float() preds = preds.flatten(2) U, S, Vh = torch.linalg.svd(preds, full_matrices=False) @@ -129,9 +130,9 @@ def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guid x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) x_Vh_V = torch.matmul(x_Vh, Vh_modified) pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) - + pred = pred_cond if use_original_formulation else pred_uncond shift = pred_cond - pred_uncond pred = pred + guidance_scale * shift - + return pred diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 65a99464ba..6b847271c9 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -20,7 +20,12 @@ import torch from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn +from ._common import ( + _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, + _ATTENTION_CLASSES, + _FEEDFORWARD_CLASSES, + _get_submodule_from_fqn, +) from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -198,15 +203,15 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam for i, block in enumerate(transformer_blocks): if i not in config.indices: continue - + blocks_found = True - + if config.skip_attention and config.skip_ff: logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") registry = HookRegistry.check_if_exists_or_initialize(block) hook = TransformerBlockSkipHook(config.dropout) registry.register_hook(hook, name) - + elif config.skip_attention or config.skip_attention_scores: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: @@ -215,7 +220,7 @@ def 
_apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) registry.register_hook(hook, name) - + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index f0366e2988..353ce72894 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional import torch import torch.nn.functional as F @@ -67,7 +67,7 @@ class SmoothedEnergyGuidanceHook(ModelHook): def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK - + if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: if hasattr(module, identifier): @@ -78,18 +78,18 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " "`fqn` (fully qualified name) that identifies a stack of transformer blocks." ) - + if config._query_proj_identifiers is None: config._query_proj_identifiers = ["to_q"] - + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) blocks_found = False for i, block in enumerate(transformer_blocks): if i not in config.indices: continue - + blocks_found = True - + for submodule_name, submodule in block.named_modules(): if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: continue @@ -103,7 +103,7 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth registry = HookRegistry.check_if_exists_or_initialize(query_proj) hook = SmoothedEnergyGuidanceHook(blur_sigma) registry.register_hook(hook, name) - + if not blocks_found: raise ValueError( f"Could not find any transformer blocks matching the provided indices {config.indices} and " @@ -124,7 +124,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma in the future without warning or guarantee of reproducibility. 
""" assert query.ndim == 3 - + is_inf = sigma > sigma_threshold_inf batch_size, seq_len, embed_dim = query.shape @@ -133,7 +133,7 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = query[:, :num_square_tokens, :] query_slice = query_slice.permute(0, 2, 1) query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) - + if is_inf: kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) kernel_size_half = (kernel_size - 1) / 2 @@ -154,5 +154,5 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.permute(0, 2, 1) query[:, :num_square_tokens, :] = query_slice.clone() - + return query diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index a5f5e6376b..335d7e623f 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -102,8 +102,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .ip_adapter import ( FluxIPAdapterMixin, IPAdapterMixin, - SD3IPAdapterMixin, ModularIPAdapterMixin, + SD3IPAdapterMixin, ) from .lora_pipeline import ( AmusedLoraLoaderMixin, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 4499634d9f..f6e398268c 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -49,13 +49,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .components_manager import ComponentsManager from .modular_pipeline import ( AutoPipelineBlocks, BlockState, LoopSequentialPipelineBlocks, ModularLoader, - ModularPipelineBlocks, ModularPipeline, + ModularPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, @@ -70,7 +71,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLAutoPipeline, StableDiffusionXLModularLoader, ) - from .components_manager import ComponentsManager else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index bdc24d474a..3f22fa7115 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -12,24 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import time +import uuid from collections import OrderedDict from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy +from typing import Any, Dict, List, Optional, Union import torch -import time -from dataclasses import dataclass from ..utils import ( is_accelerate_available, logging, ) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec - - -import uuid if is_accelerate_available(): @@ -237,12 +232,12 @@ class AutoOffloadStrategy: class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - + def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None): """ Lookup component_ids by name, collection, or load_id. @@ -251,7 +246,7 @@ class ComponentsManager: components = self.components if name: - ids_by_name = set() + ids_by_name = set() for component_id, component in components.items(): comp_name = self._id_to_name(component_id) if comp_name == name: @@ -272,16 +267,16 @@ class ComponentsManager: ids_by_load_id.add(name) else: ids_by_load_id = set(components.keys()) - + ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) return ids - + @staticmethod def _id_to_name(component_id: str): return "_".join(component_id.split("_")[:-1]) - + def add(self, name, component, collection: Optional[str] = None): - + component_id = f"{name}_{uuid.uuid4()}" # check for duplicated components @@ -305,7 +300,7 @@ class ComponentsManager: if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id) components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id] - + if components_with_same_load_id: existing = ", ".join(components_with_same_load_id) logger.warning( @@ -320,7 +315,7 @@ class ComponentsManager: if collection: if collection not in self.collections: self.collections[collection] = set() - if not component_id in self.collections[collection]: + if component_id not in self.collections[collection]: comp_ids_in_collection = self._lookup_ids(name=name, collection=collection) for comp_id in comp_ids_in_collection: logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}") @@ -331,8 +326,8 @@ class ComponentsManager: logger.info(f"Added component '{name}' as '{component_id}'") if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - + self.enable_auto_cpu_offload(self._auto_offload_device) + return component_id @@ -341,14 +336,14 @@ class ComponentsManager: if component_id not in self.components: logger.warning(f"Component '{component_id}' not found in ComponentsManager") return - + component = self.components.pop(component_id) self.added_time.pop(component_id) for collection in self.collections: if component_id in self.collections[collection]: self.collections[collection].remove(component_id) - + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) else: @@ -386,7 +381,7 @@ class ComponentsManager: Dictionary mapping component IDs to components, or list of (base_name, component) tuples if as_name_component_tuples=True """ 
- + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) components = {k: self.components[k] for k in selected_ids} @@ -397,16 +392,16 @@ class ComponentsManager: if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return component_id - + if names is None: if as_name_component_tuples: return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] else: return components - + # Create mapping from component_id to base_name for all components base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - + def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. @@ -417,124 +412,124 @@ class ComponentsManager: exact_match: If True, only exact matches to base_name are considered """ base_name = base_names[component_id] - + # Exact match with base name if exact_match: return pattern == base_name - + # Prefix match (ends with *) elif pattern.endswith('*'): prefix = pattern[:-1] return base_name.startswith(prefix) - + # Contains match (starts with *) elif pattern.startswith('*'): search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] return search in base_name - + # Exact match (no wildcards) else: return pattern == base_name - + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') if is_not_pattern: names = names[1:] # Remove the ! prefix - + # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') matches = {} - + for comp_id, comp in components.items(): # For OR patterns with exact names (no wildcards), we do exact matching on base names exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - + # Check if any of the terms match this component should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - + # Flip the decision if this is a NOT pattern if is_not_pattern: should_include = not should_include - + if should_include: matches[comp_id] = comp - + log_msg = "NOT " if is_not_pattern else "" match_type = "exactly matching" if exact_match else "matching any of patterns" logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - + # Try exact match with a base name elif any(names == base_name for base_name in base_names.values()): # Find all components with this base name matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (base_names[comp_id] == names) != is_not_pattern } - + if is_not_pattern: logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - + # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") else: logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - + # Contains match (starts with *) elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in 
components.items() if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - + # Substring match (no wildcards, but not an exact component name) elif any(names in base_name for base_name in base_names.values()): matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (names in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - + else: raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") - + if not matches: raise ValueError(f"No components found matching pattern '{names}'") - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] else: return matches - + elif isinstance(names, list): results = {} for name in names: result = self.get(name, collection, load_id, as_name_component_tuples=False) results.update(result) - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in results.items()] else: return results - + else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -595,14 +590,14 @@ class ComponentsManager: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") component = self.components[component_id] - + # Build complete info dict first info = { "model_id": component_id, "added_time": self.added_time[component_id], "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } - + # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): # Check for hook information @@ -610,7 +605,7 @@ class ComponentsManager: execution_device = None if has_hook and hasattr(component._hf_hook, "execution_device"): execution_device = component._hf_hook.execution_device - + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), @@ -631,8 +626,8 @@ class ComponentsManager: if any("IPAdapter" in ptype for ptype in processor_types): # Then get scales only from IP-Adapter processors scales = { - k: v.scale - for k, v in processors.items() + k: v.scale + for k, v in processors.items() if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ } if scales: @@ -646,7 +641,7 @@ class ComponentsManager: else: # List of fields requested, return dict with just those fields return {k: v for k, v in info.items() if k in fields} - + return info def __repr__(self): @@ -659,13 +654,13 @@ class ComponentsManager: if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return name - + # Extract load_id if available def get_load_id(component): if hasattr(component, "_diffusers_load_id"): return component._diffusers_load_id return "N/A" - + # Format device info compactly def format_device(component, info): if not info["has_hook"]: @@ -674,18 +669,18 @@ class ComponentsManager: device = str(getattr(component, 'device', 'N/A')) exec_device = str(info['execution_device'] or 'N/A') return f"{device}({exec_device})" - + # Get all simple names to calculate width simple_names = [get_simple_name(id) for id in self.components.keys()] - + # Get max length of load_ids for models 
load_ids = [ - get_load_id(component) - for component in self.components.values() + get_load_id(component) + for component in self.components.values() if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - + # Get all collections for each component component_collections = {} for name in self.components.keys(): @@ -695,11 +690,11 @@ class ComponentsManager: component_collections[name].append(coll) if not component_collections[name]: component_collections[name] = ["N/A"] - + # Find the maximum collection name length all_collections = [coll for colls in component_collections.values() for coll in colls] max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 - + col_widths = { "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), @@ -736,21 +731,21 @@ class ComponentsManager: device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) - + # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" - + # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" - + output += dash_line # Other components section @@ -766,17 +761,17 @@ class ComponentsManager: for name, component in others.items(): info = self.get_model_info(name) simple_name = get_simple_name(name) - + # Print first collection on the main line first_collection = component_collections[name][0] if component_collections[name] else "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" - + # Print additional collections on separate lines if they exist for i in range(1, len(component_collections[name])): collection = component_collections[name][i] output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" - + output += dash_line # Add additional component info @@ -789,8 +784,8 @@ class ComponentsManager: if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" - + output += " IP-Adapter: Enabled\n" + return output def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): @@ -821,13 +816,13 @@ class ComponentsManager: from ..pipelines.pipeline_utils import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): - + if component is None: continue - + # Add prefix if specified component_name = f"{prefix}_{name}" if prefix else name - + if component_name not in 
self.components: self.add(component_name, component) else: @@ -860,15 +855,15 @@ class ComponentsManager: if component_id not in self.components: raise ValueError(f"Component '{component_id}' not found in ComponentsManager") return self.components[component_id] - + results = self.get(name, collection, load_id) - + if not results: raise ValueError(f"No components found matching '{name}'") - + if len(results) > 1: raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: @@ -894,17 +889,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: if value_tuple not in value_to_keys: value_to_keys[value_tuple] = [] value_to_keys[value_tuple].append(key) - + def find_common_prefix(keys: List[str]) -> str: """Find the shortest common prefix among a list of dot-separated keys.""" if not keys: return "" if len(keys) == 1: return keys[0] - + # Split all keys into parts key_parts = [k.split('.') for k in keys] - + # Find how many initial parts are common common_length = 0 for parts in zip(*key_parts): @@ -912,10 +907,10 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: common_length += 1 else: break - + if common_length == 0: return "" - + # Return the common prefix return '.'.join(key_parts[0][:common_length]) @@ -929,5 +924,5 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: summary[prefix] = value else: summary[""] = value # Use empty string if no common prefix - + return summary diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index cdb28519f4..0d7bec5a5c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -11,49 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
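# [Editor's sketch, not part of the patch] The pattern grammar accepted by
# `ComponentsManager.get()` as implemented above. Plain objects stand in for
# real model components here; all names are hypothetical.
from diffusers import ComponentsManager

manager = ComponentsManager()
manager.add("text_encoder", object())
manager.add("text_encoder_2", object())
manager.add("unet", object())

manager.get("text_encoder")        # exact base-name match (UUID suffix ignored)
manager.get("text_encoder*")       # prefix match -> both encoders
manager.get("*encoder*")           # contains match
manager.get("!unet")               # NOT pattern -> everything except unet
manager.get("unet|text_encoder")   # OR pattern with exact terms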
+import importlib import inspect - - +import os import traceback import warnings from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional -from typing_extensions import Self from copy import deepcopy - +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from tqdm.auto import tqdm -import re -import os -import importlib - from huggingface_hub.utils import validate_hf_hub_args +from tqdm.auto import tqdm +from typing_extensions import Self from ..configuration_utils import ConfigMixin, FrozenDict +from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj from ..utils import ( + PushToHubMixin, is_accelerate_available, logging, - PushToHubMixin, ) -from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from .components_manager import ComponentsManager from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, InputParam, + InsertableOrderedDict, OutputParam, format_components, format_configs, format_inputs_short, format_intermediates_short, make_doc_string, - InsertableOrderedDict ) -from .components_manager import ComponentsManager -from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code -from copy import deepcopy + if is_accelerate_available(): import accelerate @@ -118,7 +113,7 @@ class PipelineState: def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} - + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ Get all inputs with matching kwargs_type. @@ -165,7 +160,7 @@ class PipelineState: inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - + # Format input_kwargs and intermediate_kwargs input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) @@ -180,7 +175,7 @@ class PipelineState: ) -@dataclass +@dataclass class BlockState: """ Container for block state data with attribute access and formatted representation. @@ -192,11 +187,11 @@ class BlockState: def __getitem__(self, key: str): # allows block_state["foo"] return getattr(self, key, None) - + def __setitem__(self, key: str, value: Any): # allows block_state["foo"] = "bar" setattr(self, key, value) - + def as_dict(self): """ Convert BlockState to a dictionary. 
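# [Editor's sketch, not part of the patch] BlockState is a dynamic container;
# per the __getitem__/__setitem__ shown below, item access mirrors attribute
# access. Field names here are hypothetical.
from diffusers.modular_pipelines import BlockState

state = BlockState(prompt="a cat", guidance_scale=7.5)
assert state["prompt"] == state.prompt   # __getitem__ -> getattr
state["latents"] = None                  # __setitem__ -> setattr
missing = state["not_set"]               # missing keys return None, no KeyError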
@@ -211,21 +206,21 @@ class BlockState: # Handle tensors directly if hasattr(v, "shape") and hasattr(v, "dtype"): return f"Tensor(dtype={v.dtype}, shape={v.shape})" - + # Handle lists of tensors elif isinstance(v, list): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"List[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle tuples of tensors elif isinstance(v, tuple): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle dicts with tensor values elif isinstance(v, dict): formatted_dict = {} @@ -238,7 +233,7 @@ class BlockState: else: formatted_dict[k] = repr(val) return formatted_dict - + # Default case return repr(v) @@ -261,7 +256,7 @@ class ModularPipelineBlocks(ConfigMixin): expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters - + @classmethod def from_pretrained( @@ -311,17 +306,17 @@ class ModularPipelineBlocks(ConfigMixin): def save_pretrained(self, save_directory, push_to_hub = False, **kwargs): # TODO: factor out this logic. cls_name = self.__class__.__name__ - - full_mod = type(self).__module__ + + full_mod = type(self).__module__ module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") - parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] + parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - + self.register_to_config(auto_map=auto_map) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) config = dict(self.config) self._internal_dict = FrozenDict(config) - + def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): """ create a ModularLoader, optionally accept modular_repo to load from hub. @@ -329,22 +324,22 @@ class ModularPipelineBlocks(ConfigMixin): loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) - + # Create deep copies to avoid modifying the original specs component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) # Create the loader with the updated specs specs = component_specs + config_specs - + loader = loader_class(specs=specs, pretrained_model_name_or_path=pretrained_model_name_or_path, component_manager=component_manager, collection=collection) modular_pipeline = ModularPipeline(blocks=self, loader=loader) return modular_pipeline class PipelineBlock(ModularPipelineBlocks): - + model_name = None - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -354,12 +349,12 @@ class PipelineBlock(ModularPipelineBlocks): @property def expected_components(self) -> List[ComponentSpec]: return [] - + @property def expected_configs(self) -> List[ConfigSpec]: return [] - + @property def inputs(self) -> List[InputParam]: """List of input parameters. 
Must be implemented by subclasses.""" @@ -394,7 +389,7 @@ class PipelineBlock(ModularPipelineBlocks): @property def required_inputs(self) -> List[str]: return self._get_required_inputs() - + def _get_required_intermediates_inputs(self): input_names = [] @@ -403,7 +398,7 @@ class PipelineBlock(ModularPipelineBlocks): input_names.append(input_param.name) return input_names - # YiYi TODO: maybe we do not need this, it is only used in docstring, + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -460,9 +455,9 @@ class PipelineBlock(ModularPipelineBlocks): @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -474,7 +469,7 @@ class PipelineBlock(ModularPipelineBlocks): def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - + # Check inputs for input_param in self.inputs: if input_param.name: @@ -514,14 +509,14 @@ class PipelineBlock(ModularPipelineBlocks): data[k] = v data[input_param.kwargs_type][k] = v return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) state.add_intermediate(output_param.name, param, output_param.kwargs_type) - + for input_param in self.intermediates_inputs: if hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) @@ -561,7 +556,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li """ combined_dict = {} # name -> InputParam value_sources = {} # name -> block_name - + for block_name, inputs in named_input_lists: for input_param in inputs: if input_param.name is None and input_param.kwargs_type is not None: @@ -570,8 +565,8 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li input_name = input_param.name if input_name in combined_dict: current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and + if (current_param.default is not None and + input_param.default is not None and current_param.default != input_param.default): warnings.warn( f"Multiple different default values found for input '{input_name}': " @@ -584,7 +579,7 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li else: combined_dict[input_name] = input_param value_sources[input_name] = block_name - + return list(combined_dict.values()) def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: @@ -599,12 +594,12 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: Combined list of unique OutputParam objects """ combined_dict = {} # name -> OutputParam - + for block_name, outputs in named_output_lists: for output_param in outputs: if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): 
                combined_dict[output_param.name] = output_param
-    
+
     return list(combined_dict.values())
@@ -630,7 +625,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
         if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
             raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.")
         default_blocks = [t for t in self.block_trigger_inputs if t is None]
-        # can only have 1 or 0 default block, and has to put in the last 
+        # can only have 1 or 0 default block, and it has to be placed last
         # the order of blocks matters here because the first block with matching trigger will be dispatched
         # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
         # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
@@ -650,7 +645,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
     @property
     def model_name(self):
         return next(iter(self.blocks.values())).model_name
-    
+
     @property
     def description(self):
         return ""
@@ -687,8 +682,8 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
             required_by_all.intersection_update(block_required)
         return list(required_by_all)
-    
-    # YiYi TODO: maybe we do not need this, it is only used in docstring, 
+
+    # YiYi TODO: maybe we do not need this, it is only used in docstring,
     # intermediate_inputs is by default required, unless you manually handle it inside the block
     @property
     def required_intermediates_inputs(self) -> List[str]:
@@ -736,7 +731,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
         named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
         combined_outputs = combine_outputs(*named_outputs)
         return combined_outputs
-    
+
     @property
     def outputs(self) -> List[str]:
         named_outputs = [(name, block.outputs) for name, block in self.blocks.items()]
@@ -779,24 +774,24 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
         """
         def fn_recursive_get_trigger(blocks):
             trigger_values = set()
-            
+
             if blocks is not None:
                 for name, block in blocks.items():
                     # Check if current block has trigger inputs (i.e. 
auto block) if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them if hasattr(block, 'blocks'): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + trigger_inputs = set(self.block_trigger_inputs) trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - + return trigger_inputs @property @@ -812,7 +807,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): else f"{class_name}(\n" ) - + if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -836,7 +831,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -860,7 +855,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -870,27 +865,27 @@ class AutoPipelineBlocks(ModularPipelineBlocks): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -905,15 +900,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): block_classes = [] block_names = [] - + @property def description(self): return "" - + @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def expected_components(self): @@ -944,7 +939,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): A new SequentialPipelineBlocks instance """ instance = cls() - + # Create instances if classes are provided blocks = InsertableOrderedDict() for name, block in blocks_dict.items(): @@ -952,12 +947,12 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): blocks[name] = block() else: blocks[name] = block - + instance.block_classes = [block.__class__ for block in blocks.values()] instance.block_names = list(blocks.keys()) instance.blocks = blocks return instance - + def __init__(self): blocks = InsertableOrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): @@ -975,10 +970,10 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - - # YiYi TODO: maybe we do not need this, it is only used in docstring, 
+ + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -1007,7 +1002,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): @property def intermediates_inputs(self) -> List[str]: return self.get_intermediates_inputs() - + def get_intermediates_inputs(self): inputs = [] outputs = set() @@ -1025,7 +1020,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -1043,7 +1038,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: @@ -1063,7 +1058,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): logger.error(error_msg) raise return pipeline, state - + def _get_trigger_inputs(self): """ Returns a set of all unique trigger input values found in the blocks. @@ -1071,21 +1066,21 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): """ def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs (i.e. auto block) if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them if hasattr(block, 'blocks'): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + return fn_recursive_get_trigger(self.blocks) @property @@ -1097,7 +1092,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): active_triggers = set(trigger_inputs) def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - + # sequential (including loop sequential) or PipelineBlock if not hasattr(block, 'block_trigger_inputs'): if hasattr(block, 'blocks'): @@ -1114,7 +1109,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): if hasattr(block, 'outputs'): active_triggers.update(out.name for out in block.outputs) return result_blocks - + # auto else: # Find first block_trigger_input that matches any value in our active_triggers @@ -1125,12 +1120,12 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): this_block = block.trigger_to_block_map[trigger_input] matching_trigger = trigger_input break - + # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] matching_trigger = None - + if this_block is not None: # sequential/auto (keep traversing) if hasattr(this_block, 'blocks'): @@ -1144,13 +1139,13 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): active_triggers.update(out.name for out in this_block.outputs) return result_blocks - + all_blocks = OrderedDict() for block_name, block in self.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
all_blocks.update(blocks_to_update) return all_blocks - + def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs @@ -1164,7 +1159,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" ) trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - + if trigger_inputs is None: if None in trigger_inputs_all: trigger_inputs = [None] @@ -1172,7 +1167,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): trigger_inputs = [trigger_inputs_all[0]] blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - + def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ @@ -1182,7 +1177,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): else f"{class_name}(\n" ) - + if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -1206,7 +1201,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1230,7 +1225,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -1240,27 +1235,27 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, @@ -1276,7 +1271,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): model_name = None block_classes = [] block_names = [] - + @property def description(self) -> str: """Description of the block. 
Must be implemented by subclasses.""" @@ -1285,7 +1280,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): @property def loop_expected_components(self) -> List[ComponentSpec]: return [] - + @property def loop_expected_configs(self) -> List[ConfigSpec]: return [] @@ -1365,8 +1360,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): @property def inputs(self): return self.get_inputs() - - + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs @property def intermediates_inputs(self): @@ -1392,7 +1387,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -1414,10 +1409,10 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - # YiYi TODO: maybe we do not need this, it is only used in docstring, + # YiYi TODO: maybe we do not need this, it is only used in docstring, # intermediate_inputs is by default required, unless you manually handle it inside the block @property def required_intermediates_inputs(self) -> List[str]: @@ -1441,7 +1436,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): if output.name not in set([output.name for output in combined_outputs]): combined_outputs.append(output) return combined_outputs - + # YiYi TODO: this needs to be thought about more # copied from SequentialPipelineBlocks @property @@ -1454,7 +1449,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks - + @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
@@ -1485,15 +1480,15 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): logger.error(error_msg) raise return components, state - + def __call__(self, components, state: PipelineState) -> PipelineState: raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - - + + def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - + # Check inputs for input_param in self.inputs: if input_param.name: @@ -1533,7 +1528,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): data[k] = v data[input_param.kwargs_type][k] = v return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): @@ -1563,17 +1558,17 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, expected_configs=self.expected_configs ) - # modified from SequentialPipelineBlocks, - #(does not need trigger_inputs related part so removed them, + # modified from SequentialPipelineBlocks, + # (does not need the trigger_inputs related parts, so they were removed; # no need to support auto blocks for loop blocks) def __repr__(self): class_name = self.__class__.__name__ @@ -1597,7 +1592,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1605,10 +1600,10 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): - + # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description desc_lines = block.description.split('\n') indented_desc = desc_lines[0] @@ -1618,18 +1613,18 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): # Build the representation with conditional sections result = f"{header}\n{desc}" - + # Only add components section if it has content if components_str.strip(): result += f"\n\n{components_str}" - + # Only add configs section if it has content if configs_str.strip(): result += f"\n\n{configs_str}" - + # Always add blocks section result += f"\n\n{blocks_str})" - + return result @torch.compiler.disable @@ -1652,7 +1647,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): self._progress_bar_config = kwargs -# YiYi TODO: +# YiYi TODO: # 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ always creates a default modular_model_index config # 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 3. do we need ConfigSpec? seems pretty unnecessary for the loader; we could just add kwargs to the loader directly
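# A minimal sketch (not library code) of the get_block_state/add_block_state round trip
# implemented above: a block copies its declared inputs and intermediates out of the
# shared PipelineState into a local BlockState, works on the local copy, then writes its
# intermediates back. `MyBlock` and the `latents` attribute are hypothetical; only the
# method names and the (components, state) calling convention come from the code above.
class MyBlock(PipelineBlock):  # hypothetical block, for illustration only
    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)        # inputs + intermediates -> BlockState
        block_state.latents = block_state.latents * 0.5  # hypothetical work on the local copy
        self.add_block_state(state, block_state)         # intermediates written back to state
        return components, state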
@@ -1696,29 +1691,29 @@ if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue - + # check if it is the first time registration, i.e. calling from __init__ is_registered = hasattr(self, name) # make sure the component is created from ComponentSpec if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") if module is not None: # actual library and class name of the module library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config - # e.g. {"repo": "stabilityai/stable-diffusion-2-1", - # "type_hint": ("diffusers", "UNet2DConditionModel"), + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), # "subfolder": "unet", # "variant": None, # "revision": None} component_spec_dict = self._component_spec_to_dict(component_spec) - + else: # if module is None, e.g. self.register_components(unet=None) during __init__ - # we do not update the spec, + # we do not update the spec, # but we still need to update the modular_model_index.json config based on component spec library, class_name = None, None component_spec_dict = self._component_spec_to_dict(component_spec) @@ -1732,7 +1727,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue - + current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: @@ -1764,7 +1759,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): self._component_manager.add(name, module, self._collection) - + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], pretrained_model_name_or_path: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): """ @@ -1792,7 +1787,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): elif name in self._config_specs: self._config_specs[name].default = value - + register_components_dict = {} for name, component_spec in self._component_specs.items(): if component_spec.default_creation_method == "from_config": @@ -1801,7 +1796,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): component = None register_components_dict[name] = component self.register_components(**register_components_dict) - + default_configs = {} for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default @@ -1844,7 +1839,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): ): return torch.device(module._hf_hook.execution_device) return self.device - + @property def dtype(self) -> torch.dtype: @@ -1871,12 +1866,6 @@ class ModularLoader(ConfigMixin, PushToHubMixin): }
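# A hedged sketch of the registration rule enforced above: only objects stamped by
# ComponentSpec.create()/load() (which set `_diffusers_load_id`) are accepted, and a
# load_id of "null" marks a config-created component, which is registered on the loader
# but not tracked by the components manager. The helper names below are hypothetical.
def accepts_component(module) -> bool:
    # None is allowed: placeholder registration during __init__
    return module is None or hasattr(module, "_diffusers_load_id")

def is_pretrained_component(module) -> bool:
    # "null" means created via ComponentSpec.create() (from_config),
    # so it is not added to the components manager
    return getattr(module, "_diffusers_load_id", "null") != "null"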
def update(self, **kwargs): - """ - Update components and configs after instance creation. - - Args: - - """ """ Update components and configuration values after the loader has been instantiated. @@ -1917,7 +1906,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") ) ``` - """ + """ # extract component_specs_updates & config_specs_updates from `specs` passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} @@ -1926,7 +1915,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): @@ -1942,14 +1931,14 @@ class ModularLoader(ConfigMixin, PushToHubMixin): # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) self._component_specs[name] = new_component_spec - + if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - + created_components = {} for name, component_spec in passed_component_specs.items(): if component_spec.default_creation_method == "from_pretrained": - raise ValueError(f"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") + raise ValueError("ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update() method") created_components[name] = component_spec.create() current_component_spec = self._component_specs[name] # warn if type changed @@ -1991,7 +1980,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): unknown_component_names = set([name for name in component_names if name not in self._component_specs]) if len(unknown_component_names) > 0: logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - + components_to_register = {} for name in components_to_load: spec = self._component_specs[name] @@ -2011,7 +2000,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): components_to_register[name] = spec.load(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") - + # Register all components at once self.register_components(**components_to_register) @@ -2033,7 +2022,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): ) return True return False - + # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to def to(self, *args, **kwargs) -> Self: r""" @@ -2071,7 +2060,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): Returns: [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `device`. """ - from ..pipelines.pipeline_utils import _check_bnb_status, DiffusionPipeline + from ..pipelines.pipeline_utils import _check_bnb_status from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version @@ -2227,7 +2216,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): ) return self - # YiYi TODO: + # YiYi TODO: # 1. should support saving some components too! currently only modular_model_index.json is saved # 2.
maybe order the json file to make it more readable: configs first, then components def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): @@ -2241,11 +2230,11 @@ class ModularLoader(ConfigMixin, PushToHubMixin): config.pop("_configs_names", None) self._internal_dict = FrozenDict(config) - + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) expected_config = set(config_dict.pop("_configs_names")) @@ -2265,7 +2254,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) - + @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ @@ -2432,33 +2421,33 @@ class ModularPipeline: else: raise ValueError(f"Output '{output}' is not a valid output type") - + def load_components(self, component_names: Optional[List[str]] = None, **kwargs): self.loader.load(component_names=component_names, **kwargs) - + def update_components(self, **kwargs): self.loader.update(**kwargs) - + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], trust_remote_code: Optional[bool] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs) pipeline = blocks.init_pipeline(pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs) return pipeline - + def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs): self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - + + @property def doc(self): return self.blocks.doc - + def to(self, *args, **kwargs): self.loader.to(*args, **kwargs) return self - + @property def components(self): - return self.loader.components \ No newline at end of file + return self.loader.components diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index c83b2abf50..1b9874bb52 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -12,44 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
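# A brief usage sketch of the delegation pattern defined above: ModularPipeline
# forwards component management to its loader. The repo id is hypothetical; the
# method names (from_pretrained, load_components, to, components) are the ones
# defined in the wrapper above.
pipe = ModularPipeline.from_pretrained("some-org/some-modular-repo")  # hypothetical repo
pipe.load_components()        # delegates to self.loader.load(...)
pipe = pipe.to("cuda")        # delegates to self.loader.to(...)
components = pipe.components  # delegates to self.loader.components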
-import re import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal - -from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin +import re from collections import OrderedDict +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List, Literal, Optional, Type, Union + +from ..configuration_utils import ConfigMixin, FrozenDict +from ..utils.import_utils import is_torch_available + if is_torch_available(): - import torch + pass class InsertableOrderedDict(OrderedDict): def insert(self, key, value, index): items = list(self.items()) - + # Remove key if it already exists to avoid duplicates items = [(k, v) for k, v in items if k != key] - + # Insert at the specified index items.insert(index, (key, value)) - + # Clear and update self self.clear() self.update(items) - + # Return self for method chaining return self - + def __repr__(self): if not self: return "InsertableOrderedDict()" - + items = [] for i, (key, value) in enumerate(self.items()): items.append(f"{i}: ({repr(key)}, {repr(value)})") - + return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])" @@ -85,24 +86,24 @@ class ComponentSpec: variant: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True}) default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - + + def __hash__(self): """Make ComponentSpec hashable, using load_id as the hash value.""" return hash((self.name, self.load_id, self.default_creation_method)) - + def __eq__(self, other): """Compare ComponentSpec objects based on name and load_id.""" if not isinstance(other, ComponentSpec): return False - return (self.name == other.name and - self.load_id == other.load_id and + return (self.name == other.name and + self.load_id == other.load_id and self.default_creation_method == other.default_creation_method) - + @classmethod def from_component(cls, name: str, component: Any) -> Any: """Create a ComponentSpec from a Component created by `create` or `load` method.""" - + if not hasattr(component, "_diffusers_load_id"): raise ValueError("Component is not created by `create` or `load` method") # throw a error if component is created with `create` method but not a subclass of ConfigMixin @@ -113,19 +114,19 @@ class ComponentSpec: "created with `ComponentSpec.load` method" "or created with `ComponentSpec.create` and a subclass of ConfigMixin" ) - + type_hint = component.__class__ default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained" - + if isinstance(component, ConfigMixin) and default_creation_method == "from_config": config = component.config else: config = None - + load_spec = cls.decode_load_id(component._diffusers_load_id) - + return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec) - + @classmethod def loading_fields(cls) -> List[str]: """ @@ -133,8 +134,8 @@ class ComponentSpec: (i.e. those whose field.metadata["loading"] is True). 
""" return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - + + @property def load_id(self) -> str: """ @@ -144,7 +145,7 @@ class ComponentSpec: parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] return "|".join(p for p in parts if p) - + @classmethod def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: """ @@ -165,26 +166,26 @@ class ComponentSpec: If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating component not created with `load` method). """ - + # Get all loading fields in order loading_fields = cls.loading_fields() result = {f: None for f in loading_fields} if load_id == "null": return result - + # Split the load_id parts = load_id.split("|") - + # Map parts to loading fields by position for i, part in enumerate(parts): if i < len(loading_fields): # Convert "null" string back to None result[loading_fields[i]] = None if part == "null" else part - + return result - - + + # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) # the config info is lost in the process @@ -194,11 +195,11 @@ class ComponentSpec: if self.type_hint is None or not isinstance(self.type_hint, type): raise ValueError( - f"`type_hint` is required when using from_config creation method." + "`type_hint` is required when using from_config creation method." ) - + config = config or self.config or {} - + if issubclass(self.type_hint, ConfigMixin): component = self.type_hint.from_config(config, **kwargs) else: @@ -211,17 +212,17 @@ class ComponentSpec: if k in signature_params: init_kwargs[k] = v component = self.type_hint(**init_kwargs) - + component._diffusers_load_id = "null" if hasattr(component, "config"): self.config = component.config - + return component - + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained def load(self, **kwargs) -> Any: """Load component using from_pretrained.""" - + # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} # merge loading field value in the spec with user passed values to create load_kwargs @@ -229,8 +230,8 @@ class ComponentSpec: # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - + raise ValueError("`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + if self.type_hint is None: try: from diffusers import AutoModel @@ -244,17 +245,17 @@ class ComponentSpec: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: raise ValueError(f"Unable to load {self.name} using load method: {e}") - + self.repo = repo for k, v in load_kwargs.items(): setattr(self, k, v) component._diffusers_load_id = self.load_id - + return component - -@dataclass + +@dataclass class ConfigSpec: """Specification for a pipeline configuration parameter.""" name: str @@ -281,7 +282,7 @@ class InputParam: return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" -@dataclass +@dataclass class OutputParam: """Specification for an output parameter.""" name: str @@ -315,14 +316,14 @@ def format_inputs_short(inputs): """ required_inputs = [param for param in inputs if param.required] optional_inputs = [param for param in inputs if not param.required] - + required_str = ", ".join(param.name for param in required_inputs) optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - + inputs_str = required_str if optional_str: inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - + return inputs_str @@ -353,18 +354,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu else: inp_name = inp.name input_parts.append(inp_name) - + # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} modified_parts = [] new_output_parts = [] - + for out in intermediates_outputs: if out.name in inputs_set: modified_parts.append(out.name) else: new_output_parts.append(out.name) - + result = [] if input_parts: result.append(f" - inputs: {', '.join(input_parts)}") @@ -372,7 +373,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu result.append(f" - modified: {', '.join(modified_parts)}") if new_output_parts: result.append(f" - outputs: {', '.join(new_output_parts)}") - + return "\n".join(result) if result else " (none)" @@ -390,18 +391,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): """ if not params: return "" - + base_indent = " " * indent_level param_indent = " " * (indent_level + 4) desc_indent = " " * (indent_level + 8) formatted_params = [] - + def get_type_str(type_hint): if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] return f"Union[{', '.join(types)}]" return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - + def wrap_text(text, indent, max_length): """Wrap text while preserving markdown links and maintaining indentation.""" words = text.split() @@ -411,7 +412,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): for word in words: word_length = len(word) + (1 if current_line else 0) - + if current_line and current_length + word_length > max_length: lines.append(" ".join(current_line)) current_line = [word] @@ -419,22 +420,22 @@ def 
format_params(params, header="Args", indent_level=4, max_line_length=115): else: current_line.append(word) current_length += word_length - + if current_line: lines.append(" ".join(current_line)) - + return f"\n{indent}".join(lines) - + # Add the header formatted_params.append(f"{base_indent}{header}:") - + for param in params: # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" # YiYi Notes: remove this line if we remove kwargs_type name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name param_str = f"{param_indent}{name} (`{type_str}`" - + # Add optional tag and default value if parameter is an InputParam and optional if hasattr(param, "required"): if not param.required: @@ -442,7 +443,7 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): if param.default is not None: param_str += f", defaults to {param.default}" param_str += "):" - + # Add description on a new line with additional indentation and wrapping if param.description: desc = re.sub( @@ -452,9 +453,9 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): ) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" - + formatted_params.append(param_str) - + return "\n\n".join(formatted_params) @@ -500,42 +501,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty """ if not components: return "" - + base_indent = " " * indent_level component_indent = " " * (indent_level + 4) formatted_components = [] - + # Add the header formatted_components.append(f"{base_indent}Components:") if add_empty_lines: formatted_components.append("") - + # Add each component with optional empty lines between them for i, component in enumerate(components): # Get type name, handling special cases type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - + component_desc = f"{component_indent}{component.name} (`{type_name}`)" if component.description: component_desc += f": {component.description}" - + # Get the loading fields dynamically loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) if field_value is not None: loading_field_values.append(f"{field_name}={field_value}") - + # Add loading field information if available if loading_field_values: component_desc += f" [{', '.join(loading_field_values)}]" - + formatted_components.append(component_desc) - + # Add an empty line after each component except the last one if add_empty_lines and i < len(components) - 1: formatted_components.append("") - + return "\n".join(formatted_components) @@ -553,27 +554,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines """ if not configs: return "" - + base_indent = " " * indent_level config_indent = " " * (indent_level + 4) formatted_configs = [] - + # Add the header formatted_configs.append(f"{base_indent}Configs:") if add_empty_lines: formatted_configs.append("") - + # Add each config with optional empty lines between them for i, config in enumerate(configs): config_desc = f"{config_indent}{config.name} (default: {config.default})" if config.description: config_desc += f": {config.description}" formatted_configs.append(config_desc) - + # Add an empty line after each config except the last one if add_empty_lines and i < len(configs) - 1: formatted_configs.append("") - + return 
"\n".join(formatted_configs) @@ -618,9 +619,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class # Add inputs section output += format_input_params(inputs + intermediates_inputs, indent_level=2) - + # Add outputs section output += "\n\n" output += format_output_params(outputs, indent_level=2) - return output \ No newline at end of file + return output diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py index 5f5e1c6c78..4855a9bcfc 100644 --- a/src/diffusers/modular_pipelines/node_utils.py +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -1,16 +1,19 @@ -from ..configuration_utils import ConfigMixin -from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks -from .modular_pipeline_utils import InputParam, OutputParam -from ..image_processor import PipelineImageInput -from pathlib import Path import json -import os - -from typing import Union, List, Optional, Tuple -import torch -import PIL -import numpy as np import logging +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch + +from ..configuration_utils import ConfigMixin +from ..image_processor import PipelineImageInput +from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks +from .modular_pipeline_utils import InputParam + + logger = logging.getLogger(__name__) # YiYi Notes: this is actually for SDXL, put it here for now @@ -189,8 +192,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): if group_key in name: return group_name return None - - + + class ModularNode(ConfigMixin): config_name = "node_config.json" @@ -214,15 +217,15 @@ class ModularNode(ConfigMixin): self.name_mapping = {} input_params = {} - # pass or create a default param dict for each input + # pass or create a default param dict for each input # e.g. for prompt, # prompt = { # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers - # "label": "Prompt", - # "type": "string", - # "default": "a bear sitting in a chair drinking a milkshake", - # "display": "textarea"} - # if type is not specified, it'll be a "custom" param of its own type + # "label": "Prompt", + # "type": "string", + # "default": "a bear sitting in a chair drinking a milkshake", + # "display": "textarea"} + # if type is not specified, it'll be a "custom" param of its own type # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. 
text_encoder= {name: {"text_encoders": "text_encoder"}} @@ -236,10 +239,10 @@ class ModularNode(ConfigMixin): if mellon_name != inp.name: self.name_mapping[inp.name] = mellon_name continue - - if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): + + if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): continue - + if inp.name in DEFAULT_PARAM_MAPS: # first check if it's in the default param map, if so, directly use that param = DEFAULT_PARAM_MAPS[inp.name].copy() @@ -248,7 +251,7 @@ class ModularNode(ConfigMixin): if inp.name not in self.name_mapping: self.name_mapping[inp.name] = param else: - # if not, check if it's in the SDXL input schema, if so, + # if not, check if it's in the SDXL input schema, if so, # 1. use the type hint to determine the type # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} if inp.type_hint is not None: @@ -285,7 +288,7 @@ class ModularNode(ConfigMixin): break if to_exclude: continue - + if get_group_name(comp.name): param = get_group_name(comp.name) if comp.name not in self.name_mapping: @@ -303,7 +306,7 @@ class ModularNode(ConfigMixin): outputs = self.blocks.blocks[last_block_name].intermediates_outputs else: outputs = self.blocks.intermediates_outputs - + for out in outputs: param = kwargs.pop(out.name, None) if param: @@ -326,10 +329,10 @@ class ModularNode(ConfigMixin): param = out.name # add the param dict to the outputs dict output_params[out.name] = param - + if len(kwargs) > 0: logger.warning(f"Unused kwargs: {kwargs}") - + register_dict = { "category": category, "label": label, @@ -339,7 +342,7 @@ class ModularNode(ConfigMixin): "name_mapping": self.name_mapping, } self.register_to_config(**register_dict) - + def setup(self, components, collection=None): self.blocks.setup_loader(component_manager=components, collection=collection) self._components_manager = components @@ -347,7 +350,7 @@ class ModularNode(ConfigMixin): @property def mellon_config(self): return self._convert_to_mellon_config() - + def _convert_to_mellon_config(self): node = {} @@ -368,13 +371,13 @@ class ModularNode(ConfigMixin): } else: param = inp_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") - + for comp_name, comp_param in self.config.component_params.items(): if comp_name in self.name_mapping: mellon_name = self.name_mapping[comp_name] @@ -388,13 +391,13 @@ class ModularNode(ConfigMixin): } else: param = comp_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") - + for out_name, out_param in self.config.output_params.items(): if out_name in self.name_mapping: mellon_name = self.name_mapping[out_name] @@ -408,7 +411,7 @@ class ModularNode(ConfigMixin): } else: param = out_param - + if mellon_name not in node_param: node_param[mellon_name] = param else: @@ -427,22 +430,22 @@ class ModularNode(ConfigMixin): Path: Path to the saved config file """ file_path = Path(file_path) - + # Create directory if it doesn't exist os.makedirs(file_path.parent, exist_ok=True) - + # Create a combined dictionary with module definition and name mapping config = { "module": self.mellon_config, "name_mapping": self.name_mapping } - + # Save the config to file with open(file_path, 'w', 
encoding='utf-8') as f: json.dump(config, f, indent=2) - + logger.info(f"Mellon config and name mapping saved to {file_path}") - + return file_path @classmethod @@ -457,16 +460,16 @@ class ModularNode(ConfigMixin): dict: The loaded combined configuration containing 'module' and 'name_mapping' """ file_path = Path(file_path) - + if not file_path.exists(): raise FileNotFoundError(f"Config file not found: {file_path}") - + with open(file_path, 'r', encoding='utf-8') as f: config = json.load(f) - + logger.info(f"Mellon config loaded from {file_path}") - - + + return config def process_inputs(self, **kwargs): @@ -483,7 +486,7 @@ class ModularNode(ConfigMixin): if comp: params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) - + params_run = {} for inp_name, inp_param in self.config.input_params.items(): logger.debug(f"input: {inp_name}") @@ -495,14 +498,14 @@ class ModularNode(ConfigMixin): inp = kwargs.pop(mellon_inp_name) if inp is not None: params_run[inp_name] = inp - + return_output_names = list(self.config.output_params.keys()) return params_components, params_run, return_output_names def execute(self, **kwargs): params_components, params_run, return_output_names = self.process_inputs(**kwargs) - + self.blocks.loader.update(**params_components) output = self.blocks.run(**params_run, output=return_output_names) return output diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 1fbc141ac3..2fe15bbbee 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -34,11 +34,24 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modular_pipeline_presets import StableDiffusionXLAutoPipeline - from .modular_loader import StableDiffusionXLModularLoader - from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep from .decoders import StableDiffusionXLAutoDecodeStep - from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS + from .encoders import ( + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLTextEncoderStep, + ) + from .modular_block_mappings import ( + AUTO_BLOCKS, + CONTROLNET_BLOCKS, + CONTROLNET_UNION_BLOCKS, + IMAGE2IMAGE_BLOCKS, + INPAINT_BLOCKS, + IP_ADAPTER_BLOCKS, + SDXL_SUPPORTED_BLOCKS, + TEXT2IMAGE_BLOCKS, + ) + from .modular_loader import StableDiffusionXLModularLoader + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index f6ff339675..2032a57dcf 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -13,32 +13,27 @@ # limitations under the License. 
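# A small sketch of reading back the config saved above: per the code, the JSON
# object on disk holds the node definition under "module" and the diffusers-to-mellon
# name mapping under "name_mapping". The file name below is hypothetical.
import json
from pathlib import Path

path = Path("node_mellon_config.json")  # hypothetical file name
with open(path, "r", encoding="utf-8") as f:
    config = json.load(f)
node_definition = config["module"]       # the mellon node definition
name_mapping = config["name_mapping"]    # diffusers name -> mellon name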
import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import Any, List, Optional, Tuple, Union import PIL import torch -from collections import OrderedDict -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...utils import logging -from ...utils.torch_utils import randn_tensor, unwrap_module - -from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel from ...schedulers import EulerDiscreteScheduler -from ...configuration_utils import FrozenDict - -from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ( AutoPipelineBlocks, - ModularLoader, PipelineBlock, PipelineState, SequentialPipelineBlocks, ) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_loader import StableDiffusionXLModularLoader + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -237,7 +232,7 @@ class StableDiffusionXLInputStep(PipelineBlock): InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), ] - + @property def intermediates_outputs(self) -> List[str]: return [ @@ -250,7 +245,7 @@ class StableDiffusionXLInputStep(PipelineBlock): OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), ] - + def check_inputs(self, components, block_state): if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: @@ -270,13 +265,13 @@ class StableDiffusionXLInputStep(PipelineBlock): raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) - + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): raise ValueError("`negative_ip_adapter_embeds` must be a list") - + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: @@ -298,19 +293,19 @@ class StableDiffusionXLInputStep(PipelineBlock): # duplicate text embeddings for each generation per prompt, using mps friendly method block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - + if block_state.negative_prompt_embeds is not None: _, seq_len, _ = block_state.negative_prompt_embeds.shape block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - + if block_state.negative_pooled_prompt_embeds is not None: block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - + if block_state.ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) @@ -318,7 +313,7 @@ class StableDiffusionXLInputStep(PipelineBlock): if block_state.negative_ip_adapter_embeds is not None: for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - + self.add_block_state(state, block_state) return components, state @@ -356,14 +351,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), ] @property def intermediates_outputs(self) -> List[str]: return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", 
type_hint=int, description="The number of denoising steps to perform at inference time"), OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] @@ -455,7 +450,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): return [ ComponentSpec("scheduler", EulerDiscreteScheduler), ] - + @property def description(self) -> str: return ( @@ -473,7 +468,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] @@ -524,7 +519,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), InputParam( - "strength", + "strength", default=0.9999, description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " @@ -540,46 +535,46 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): return [ InputParam("generator"), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, + "latent_timestep", + required=True, + type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), + ), InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, + "image_latents", + required=True, + type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "mask", - required=True, - type_hint=torch.Tensor, + "mask", + required=True, + type_hint=torch.Tensor, description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "masked_image_latents", - type_hint=torch.Tensor, + "masked_image_latents", + type_hint=torch.Tensor, description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] @@ -587,13 +582,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): # YiYi TODO: update the _encode_vae_image so that we can use # Copied from @staticmethod def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -619,7 +614,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents + return image_latents # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument def prepare_latents_inpaint( @@ -737,15 +732,15 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype block_state.device = components._execution_device - + block_state.is_strength_max = block_state.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents @@ -822,9 +817,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation.
Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] @property @@ -886,14 +881,14 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): return [ InputParam("generator"), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @@ -902,8 +897,8 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam( - "latents", - type_hint=torch.Tensor, + "latents", + type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ) ] @@ -980,7 +975,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_configs(self) -> List[ConfigSpec]: return [ConfigSpec("requires_aesthetics_score", False),] @@ -1008,15 +1003,15 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -1183,29 +1178,29 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), + ), InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @@ -1344,26 +1339,26 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." 
), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], + "crops_coords", + type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), ] @@ -1395,12 +1390,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - + image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size @@ -1416,9 +1411,9 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - + block_state = self.get_block_state(state) - + # (1) prepare controlnet inputs block_state.device = components._execution_device block_state.height, block_state.width = block_state.latents.shape[-2:] @@ -1446,14 +1441,14 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) # (1.3) - # global_pool_conditions + # global_pool_conditions block_state.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) # (1.4) - # guess_mode + # guess_mode block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions # (1.5) @@ -1501,12 +1496,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - + block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - + self.add_block_state(state, block_state) return components, state @@ -1542,32 +1537,32 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." 
), InputParam( - "dtype", - required=True, - type_hint=torch.dtype, + "dtype", + required=True, + type_hint=torch.dtype, description="The dtype of model tensor inputs. Can be generated in input step." - ), + ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." ), InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], + "crops_coords", + type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), ] @@ -1599,12 +1594,12 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - + image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size @@ -1618,7 +1613,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - + block_state = self.get_block_state(state) controlnet = unwrap_module(components.controlnet) @@ -1651,7 +1646,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): if len(block_state.control_image) != len(block_state.control_mode): raise ValueError("Expected len(control_image) == len(control_type)") - # control_type + # control_type block_state.num_control_type = controlnet.config.num_control_type block_state.control_type = [0 for _ in range(block_state.num_control_type)] for control_idx in block_state.control_mode: @@ -1676,7 +1671,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): crops_coords=block_state.crops_coords, ) block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - + # controlnet_keep block_state.controlnet_keep = [] for i in range(len(block_state.timesteps)): @@ -1687,7 +1682,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): block_state.control_type_idx = block_state.control_mode block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - + self.add_block_state(state, block_state) return components, state @@ -1698,7 +1693,7 @@ class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] block_names = ["controlnet_union", "controlnet"] block_trigger_inputs = ["control_mode", "control_image"] - + @property def description(self): return "Controlnet Input step that prepares the controlnet input.\n" + \ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index ca848e2098..3a4e141775 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -12,29 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License.
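# Hedged sketch (an assumption, not the actual AutoPipelineBlocks implementation):
# how a trigger-based dispatcher like StableDiffusionXLControlNetAutoInput above can
# select a sub-block from whichever trigger input is present in the pipeline state.
# `TriggerDispatchSketch`, `select_block`, and the plain-dict state are illustrative
# stand-ins for the real block/state classes.
from typing import Optional

class TriggerDispatchSketch:
    block_names = ["controlnet_union", "controlnet"]
    block_trigger_inputs = ["control_mode", "control_image"]

    def select_block(self, state: dict) -> Optional[str]:
        # earlier entries take precedence: a ControlNet-union `control_mode`
        # input wins over a plain `control_image` when both are present
        for name, trigger in zip(self.block_names, self.block_trigger_inputs):
            if state.get(trigger) is not None:
                return name
        return None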
-import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import Any, List, Tuple, Union +import numpy as np import PIL import torch -import numpy as np -from collections import OrderedDict -from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...utils import logging - -from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput -from ...configuration_utils import FrozenDict - -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline import ( AutoPipelineBlocks, PipelineBlock, PipelineState, SequentialPipelineBlocks, ) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,15 +40,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableDiffusionXLDecodeStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -160,10 +156,10 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("image", required=True), - InputParam("mask_image", required=True), + InputParam("mask_image", required=True), InputParam("padding_mask_crop"), ] - + @property def intermediates_inputs(self) -> List[str]: return [ diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 3a8bca74b5..5646651100 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -13,28 +13,25 @@ # limitations under the License. 
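# Hedged sketch of the classifier-free guidance combine that a guider applies after
# the denoiser blocks below produce one noise prediction per guidance batch. This is
# the textbook CFG formula, not the ClassifierFreeGuidance implementation; the
# function name is an illustrative assumption.
import torch

def cfg_combine(noise_pred_uncond: torch.Tensor, noise_pred_cond: torch.Tensor, guidance_scale: float = 7.5) -> torch.Tensor:
    # noise_hat = eps_uncond + w * (eps_cond - eps_uncond);
    # w == 1 recovers the unguided conditional prediction
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)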
import inspect -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple import torch -from tqdm.auto import tqdm from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import ControlNetModel, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import logging -from ...utils.torch_utils import unwrap_module - -from ...guiders import ClassifierFreeGuidance -from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, PipelineBlock, PipelineState, - AutoPipelineBlocks, - LoopSequentialPipelineBlocks, - BlockState, ) -from dataclasses import asdict +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_loader import StableDiffusionXLModularLoader + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -61,9 +58,9 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), ] @@ -96,19 +93,19 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), ] @@ -133,7 +130,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `components.unet` or your `mask_image` or `image` input." ) - + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): @@ -155,9 +152,9 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ] @@ -178,18 +175,18 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): def intermediates_inputs(self) -> List[str]: return [ InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." 
), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." ), InputParam( - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", description=( "All conditional model inputs that need to be prepared with guider. " "It should contain prompt_embeds/negative_prompt_embeds, " @@ -202,10 +199,10 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): ] - + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) guider_input_fields ={ @@ -231,7 +228,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): cond_kwargs = guider_state_batch.as_dict() cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} prompt_embeds = cond_kwargs.pop("prompt_embeds") - + # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( @@ -259,9 +256,9 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), @@ -281,18 +278,18 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): def intermediates_inputs(self) -> List[str]: return [ InputParam( - "controlnet_cond", + "controlnet_cond", required=True, type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "conditioning_scale", + "conditioning_scale", type_hint=float, description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "guess_mode", + "guess_mode", required=True, type_hint=bool, description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." @@ -304,18 +301,18 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." ), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - kwargs_type="guider_input_fields", + kwargs_type="guider_input_fields", description=( "All conditional model inputs that need to be prepared with guider. 
" "It should contain prompt_embeds/negative_prompt_embeds, " @@ -326,7 +323,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): ) ), InputParam( - kwargs_type="controlnet_kwargs", + kwargs_type="controlnet_kwargs", description=( "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" @@ -369,14 +366,14 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] - + # default controlnet output/unet input for guess mode + conditional path block_state.down_block_res_samples_zeros = None block_state.mid_block_res_sample_zeros = None - + # guided denoiser step components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - + # Prepare mini‐batches according to guidance method and `guider_input_fields` # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. # e.g. for CFG, we prepare two batches: one for uncond, one for cond @@ -387,7 +384,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): # run the denoiser for each guidance batch for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) - + # Prepare additional conditionings added_cond_kwargs = { "text_embeds": guider_state_batch.text_embeds, @@ -395,7 +392,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): } if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds - + # Prepare controlnet additional conditionings controlnet_added_cond_kwargs = { "text_embeds": guider_state_batch.text_embeds, @@ -418,13 +415,13 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): return_dict=False, **extra_controlnet_kwargs, ) - + # assign it to block_state so it will be available for the uncond guidance batch if block_state.down_block_res_samples_zeros is None: block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] if block_state.mid_block_res_sample_zeros is None: block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) - + # Predict the noise # store the noise_pred in guider_state_batch so we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( @@ -439,7 +436,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): return_dict=False, )[0] components.guider.cleanup_models(components.unet) - + # Perform guidance block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) @@ -475,7 +472,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - + #YiYi TODO: move this out of here @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): @@ -499,7 +496,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, 
**block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -534,24 +531,24 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): return [ InputParam("generator"), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." ), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), ] @@ -559,7 +556,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): @property def intermediates_outputs(self) -> List[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - + @staticmethod def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): @@ -570,7 +567,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): extra_kwargs[key] = value return extra_kwargs - + def check_inputs(self, components, block_state): if components.num_channels_unet == 4: if block_state.image_latents is None: @@ -582,9 +579,9 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - + self.check_inputs(components, block_state) - + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) @@ -592,12 +589,12 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - + if block_state.latents.dtype != block_state.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 block_state.latents = block_state.latents.to(block_state.latents_dtype) - + # adjust latent for inpainting if components.num_channels_unet == 4: block_state.init_latents_proper = block_state.image_latents @@ -629,32 +626,32 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): def loop_expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] - + @property def loop_intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), ] - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -782,619 +779,4 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): "This is an auto pipeline block that works for text2img, img2img and inpainting tasks. It can be used with or without controlnet." " - `StableDiffusionXLDenoiseStep` (denoise) is used when no controlnet_cond is provided (works for text2img, img2img and inpainting tasks)." " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (works for text2img, img2img and inpainting tasks)." - ) - - - - - - - -# YiYi Notes: alternatively, this is you can just write the denoise loop using a pipeline block, easier but not composible -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
-# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
-# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. -# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - - -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
-# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
-# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state \ No newline at end of file + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index ca4efe2c4a..a563ffbbbe 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
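The removed denoise block above filtered scheduler and ControlNet kwargs through prepare_extra_kwargs, which inspects a callable's signature and forwards only the arguments it accepts, so one call site can serve components whose signatures differ (e.g. schedulers with or without eta or generator). A minimal standalone sketch of that pattern; the names filter_kwargs_for and step are hypothetical, chosen here only for illustration:

import inspect

def filter_kwargs_for(func, exclude=(), **kwargs):
    # Forward only the kwargs that func declares, minus explicit exclusions.
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted and k not in exclude}

def step(model_output, timestep, sample, eta=0.0):  # stand-in for a scheduler's step()
    pass

extra = filter_kwargs_for(step, generator=None, eta=0.3)
assert extra == {"eta": 0.3}  # generator is dropped: step() does not accept it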
-import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple -import PIL import torch -from collections import OrderedDict +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...utils import ( USE_PEFT_BACKEND, @@ -30,26 +35,10 @@ from ...utils import ( scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel -from ...configuration_utils import FrozenDict - -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance - +from ..modular_pipeline import AutoPipelineBlocks, PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_loader import StableDiffusionXLModularLoader -from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,7 +60,7 @@ def retrieve_latents( class StableDiffusionXLIPAdapterStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -79,7 +68,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -87,8 +76,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec( - "guider", - ClassifierFreeGuidance, + "guider", + ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -97,8 +86,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): def inputs(self) -> List[InputParam]: return [ InputParam( - "ip_adapter_image", - PipelineImageInput, + "ip_adapter_image", + PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" ) @@ -111,7 +100,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP 
adapter image embeddings"), OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") ] - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components @staticmethod def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): @@ -137,7 +126,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds - + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds @@ -219,7 +208,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): return( "Text Encoder step that generate text_embeddings to guide the image generation" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -228,9 +217,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -546,7 +535,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -558,9 +547,9 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -576,7 +565,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): def intermediates_inputs(self) -> List[InputParam]: return [ InputParam("generator"), - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property @@ -586,13 +575,13 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -618,8 +607,8 @@ class 
StableDiffusionXLVaeEncoderStep(PipelineBlock):
         else:
             image_latents = components.vae.config.scaling_factor * image_latents
 
-        return image_latents 
-    
+        return image_latents
+
 
     @torch.no_grad()
@@ -628,7 +617,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
         block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
         block_state.device = components._execution_device
         block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-        
+
         block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
         block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
 
@@ -651,23 +640,23 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
 
 class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
     model_name = "stable-diffusion-xl"
-    
+
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
             ComponentSpec("vae", AutoencoderKL),
             ComponentSpec(
-                "image_processor", 
-                VaeImageProcessor, 
-                config=FrozenDict({"vae_scale_factor": 8}), 
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
                 default_creation_method="from_config"),
             ComponentSpec(
-                "mask_processor", 
-                VaeImageProcessor, 
+                "mask_processor",
+                VaeImageProcessor,
                 config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
                 default_creation_method="from_config"),
         ]
-    
+
 
     @property
     def description(self) -> str:
@@ -694,21 +683,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
 
     @property
     def intermediates_outputs(self) -> List[OutputParam]:
-        return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), 
-                OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), 
-                OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), 
+        return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
+                OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
+                OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"),
                 OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
 
     # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
     # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
     def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
-        
+
         latents_mean = latents_std = None
         if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
             latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
         if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
             latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
-        
+
         dtype = image.dtype
         if components.vae.config.force_upcast:
             image = image.float()
@@ -734,7 +723,7 @@ class
StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents + return image_latents # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance @@ -784,8 +773,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: @@ -801,7 +790,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): else: block_state.crops_coords = None block_state.resize_mode = "default" - + block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) block_state.image = block_state.image.to(dtype=torch.float32) @@ -834,7 +823,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): # auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file) # Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] block_names = ["inpaint", "img2img"] block_trigger_inputs = ["mask_image", "image"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py index 4ffd685df0..9440d72319 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_block_mappings.py @@ -13,44 +13,40 @@ # limitations under the License. 
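The StableDiffusionXLAutoVaeEncoderStep defined at the end of the encoders.py hunk above pairs block_names with block_trigger_inputs, so passing mask_image routes to the inpaint encoder while a bare image routes to the img2img one. A minimal sketch of that first-match dispatch rule, under the assumption that triggers are checked in declaration order (this is not the actual AutoPipelineBlocks implementation):

def select_block(inputs, block_names, block_trigger_inputs):
    # The first trigger input present (non-None) in the call inputs wins; order matters.
    for name, trigger in zip(block_names, block_trigger_inputs):
        if inputs.get(trigger) is not None:
            return name
    return None

assert select_block({"mask_image": "m", "image": "i"}, ["inpaint", "img2img"], ["mask_image", "image"]) == "inpaint"
assert select_block({"image": "i"}, ["inpaint", "img2img"], ["mask_image", "image"]) == "img2img"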
from ..modular_pipeline_utils import InsertableOrderedDict +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import StableDiffusionXLAutoDecodeStep, StableDiffusionXLDecodeStep, StableDiffusionXLInpaintDecodeStep # Import all the necessary block classes from .denoise import ( StableDiffusionXLAutoDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseLoop, - StableDiffusionXLInpaintDenoiseLoop -) -from .before_denoise import ( - StableDiffusionXLAutoBeforeDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLImg2ImgSetTimestepsStep, - StableDiffusionXLImg2ImgPrepareLatentsStep, - StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, - StableDiffusionXLInpaintPrepareLatentsStep, - StableDiffusionXLControlNetInputStep, - StableDiffusionXLControlNetUnionInputStep + StableDiffusionXLInpaintDenoiseLoop, ) from .encoders import ( - StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, - StableDiffusionXLVaeEncoderStep, StableDiffusionXLInpaintVaeEncoderStep, - StableDiffusionXLIPAdapterStep -) -from .decoders import ( - StableDiffusionXLDecodeStep, - StableDiffusionXLInpaintDecodeStep, - StableDiffusionXLAutoDecodeStep + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, ) # YiYi notes: comment out for now, work on this later -# block mapping +# block mapping TEXT2IMAGE_BLOCKS = InsertableOrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("input", StableDiffusionXLInputStep), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py index 4af942af64..0f567513c5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
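The TEXT2IMAGE_BLOCKS preset at the end of the hunk above (truncated in this patch) maps step names to block classes in execution order. As a rough stand-in using collections.OrderedDict, with placeholder classes because only the first two entries appear in the hunk (the real InsertableOrderedDict presumably adds positional insertion on top of plain ordered-mapping semantics, which is not shown here):

from collections import OrderedDict

class TextEncoderStepPlaceholder: pass  # stands in for StableDiffusionXLTextEncoderStep
class InputStepPlaceholder: pass        # stands in for StableDiffusionXLInputStep

TEXT2IMAGE_BLOCKS_SKETCH = OrderedDict([
    ("text_encoder", TextEncoderStepPlaceholder),
    ("input", InputStepPlaceholder),
    # remaining entries elided in the hunk above
])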
-from typing import Any, List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union + +import numpy as np import PIL import torch -import numpy as np -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...image_processor import PipelineImageInput +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...pipelines.pipeline_utils import StableDiffusionMixin from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ...utils import logging - from ..modular_pipeline import ModularLoader from ..modular_pipeline_utils import InputParam, OutputParam + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py index 637c7ac306..981f4d7e03 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union, Dict from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks - -from .denoise import StableDiffusionXLAutoDenoiseStep from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep from .decoders import StableDiffusionXLAutoDecodeStep -from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep +from .denoise import StableDiffusionXLAutoDenoiseStep +from .encoders import ( + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLTextEncoderStep, +) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 47aae71984..8eb99038c1 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -15,12 +15,12 @@ """Utilities to dynamically load objects from the Hub.""" import importlib -import signal import inspect import json import os import re import shutil +import signal import sys import threading from pathlib import Path @@ -531,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module) \ No newline at end of file + return get_class_in_module(class_name, final_module)
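modular_pipeline_presets.py composes the auto steps (text encoder, IP-Adapter, VAE encoder, before-denoise, denoise, decode) through SequentialPipelineBlocks. A minimal sketch of what sequential composition implies, under assumed semantics rather than the actual diffusers implementation: sub-blocks run in order against a shared state, mirroring the (components, state) convention visible in the removed denoise code earlier in this patch, so later blocks can consume intermediates such as prompt_embeds or latents written by earlier ones.

class SequentialSketch:
    def __init__(self, blocks):
        self.blocks = blocks  # callables with signature (components, state) -> (components, state)

    def __call__(self, components, state):
        for block in self.blocks:
            components, state = block(components, state)  # later blocks see earlier intermediates
        return components, state

This is why preset ordering matters: the denoise step assumes the set-timesteps and prepare-latents blocks have already populated the state it reads from.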