From 720783e508c9f76ed4ecd05763db96a989a0ec86 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 21:13:13 +0200 Subject: [PATCH] smoothed energy guidance --- src/diffusers/__init__.py | 4 + src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/guider_utils.py | 4 +- src/diffusers/guiders/skip_layer_guidance.py | 5 +- .../guiders/smoothed_energy_guidance.py | 250 ++++++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_common.py | 11 + src/diffusers/hooks/layer_skip.py | 16 +- .../hooks/smoothed_energy_guidance_utils.py | 148 +++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 12 +- 10 files changed, 431 insertions(+), 21 deletions(-) create mode 100644 src/diffusers/guiders/smoothed_energy_guidance.py create mode 100644 src/diffusers/hooks/smoothed_energy_guidance_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0672e10356..67d8ae9f79 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -137,6 +137,7 @@ else: "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", + "SmoothedEnergyGuidance", ] ) _import_structure["hooks"].extend( @@ -145,6 +146,7 @@ else: "HookRegistry", "PyramidAttentionBroadcastConfig", "LayerSkipConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", @@ -729,12 +731,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, + SmoothedEnergyGuidance, ) from .hooks import ( FasterCacheConfig, HookRegistry, LayerSkipConfig, PyramidAttentionBroadcastConfig, + SmoothedEnergyGuidanceConfig, apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index c23e48578e..7b88d61c67 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -23,5 +23,6 @@ if is_torch_available(): from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 60859bf390..e7d22a50d0 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -55,10 +55,10 @@ class BaseGuidance: "`_input_predictions` must be a list of required prediction names for the guidance technique." ) - def force_disable(self): + def _force_disable(self): self._enabled = False - def force_enable(self): + def _force_enable(self): self._enabled = True def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 3fbfd771ef..64b2b8a73c 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -24,8 +24,9 @@ from .guider_utils import BaseGuidance, rescale_noise_cfg class SkipLayerGuidance(BaseGuidance): """ - Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): - https://huggingface.co/papers/2411.18664 + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 0000000000..bd2a61b894 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,250 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not + provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of + `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." + ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_seg_enabled(): + # If we're predicting pred_cond and pred_cond_seg only, we need to set the key to pred_cond_seg + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_seg" + self._preds[key] = pred + + if key == "pred_cond_seg": + # If we are in SLG mode, we need to remove the hooks after inference + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 142ff86037..9d0e96e9e7 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -8,3 +8,4 @@ if is_torch_available(): from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index 6ea83dcbf6..3d9c99e818 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import torch + from ..models.attention import FeedForward, LuminaFeedForward from ..models.attention_processor import Attention, MochiAttention @@ -30,3 +34,10 @@ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, } ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 14b1cf492d..45f42a1f0f 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -20,7 +20,7 @@ import torch from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -66,12 +66,13 @@ class LayerSkipConfig: def __post_init__(self): if not (0 <= self.dropout <= 1): raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): - def __init__(self) -> None: - super().__init__() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} @@ -226,10 +227,3 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam f"Could not find any transformer blocks matching the provided indices {config.indices} and " f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." ) - - -def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: - for submodule_name, submodule in module.named_modules(): - if submodule_name == fqn: - return submodule - return None diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 0000000000..20df0de048 --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. + If `None`, `to_q` is used by default. + """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query[:] = query.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice + + return query diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index aed212c3f8..76030a7153 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2231,9 +2231,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2626,9 +2626,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): # (2) Prepare conditional inputs for unet using the guider data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -3039,9 +3039,9 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0]