From 2dc673a213dba107aa7463a73295421ff1d30218 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 06:38:03 +0200 Subject: [PATCH] tangential cfg --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 3 +- .../guiders/adaptive_projected_guidance.py | 7 +- .../guiders/smoothed_energy_guidance.py | 5 +- .../tangential_classifier_free_guidance.py | 133 ++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 2 +- 6 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/guiders/tangential_classifier_free_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 67d8ae9f79..a4f55acf8b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -138,6 +138,7 @@ else: "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", ] ) _import_structure["hooks"].extend( @@ -732,6 +733,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, ) from .hooks import ( FasterCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 7b88d61c67..3c1ee29338 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -24,5 +24,6 @@ if is_torch_available(): from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] + GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] diff --git 
a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 45bd196860..05c186e58d 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -155,20 +155,25 @@ def normalized_guidance( ): diff = pred_cond - pred_uncond dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average + if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond - pred = pred + (guidance_scale - 1) * normalized_update + pred = pred + guidance_scale * normalized_update + return pred diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 2328aa82ec..906900856f 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -27,7 +27,10 @@ class SmoothedEnergyGuidance(BaseGuidance): Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 SEG is only supported as an experimental prototype feature for now, so the implementation may be modified - in the future without warning or guarantee of reproducibility. + in the future without warning or guarantee of reproducibility. 
This implementation assumes:
+    - Generated images are square (height == width)
+    - The model does not combine different modalities (e.g., text and image latent streams are not combined,
+      as they are in models such as Flux)
 
     Args:
         guidance_scale (`float`, defaults to `7.5`):
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
new file mode 100644
index 0000000000..078d795baa
--- /dev/null
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -0,0 +1,133 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union, Tuple, List
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class TangentialClassifierFreeGuidance(BaseGuidance):
+    """
+    Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]  # order in which prepare_outputs stores predictions
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None  # NOTE(review): never read in this file - confirm it is needed
+
+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1  # incremented first, so the key index below is counter - 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
+        key = self._input_predictions[self._num_outputs_prepared - 1]  # 1st call: "pred_cond", 2nd: "pred_uncond"
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None  # placeholder; reassigned in both branches below
+
+        if not self._is_tcfg_enabled():
+            pred = pred_cond  # guidance disabled: pass the conditional prediction through
+        else:  # guidance enabled: pred_uncond is stacked with pred_cond, so it must not be None here
+            pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
+
+        if self.guidance_rescale > 0.0:  # rescaling is opt-in; the default of 0.0 skips it
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0  # True only while no prediction has been counted yet
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1  # the conditional prediction is always requested
+        if self._is_tcfg_enabled():
+            num_conditions += 1  # plus the unconditional prediction while guidance is active
+        return num_conditions
+
+    def _is_tcfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True  # stays True when the number of inference steps is not set
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step  # start inclusive, stop exclusive
+
+        is_close = False  # placeholder; reassigned in both branches below
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)  # pred_cond + 0 * shift reduces to pred_cond
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)  # pred_uncond + 1 * shift reduces to pred_cond
+
+        return is_within_range and not is_close
+
+
+def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
+    cond_dtype = pred_cond.dtype  # restored on the projected tensor after the float32 computation
+    preds = torch.stack([pred_cond, pred_uncond], dim=1).float()  # (dim0, 2, ...); assumes dim0 is batch - TODO confirm
+    preds = preds.flatten(2)  # (dim0, 2, N): everything after dim 1 is flattened
+    U, S, Vh = torch.linalg.svd(preds, full_matrices=False)  # NOTE(review): only Vh is used; U and S are unused
+    Vh_modified = Vh.clone()
+    Vh_modified[:, 1] = 0  # zero the second row of Vh for every item along dim0
+
+    uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()  # (dim0, 1, N)
+    x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))  # coefficients of pred_uncond along the rows of Vh
+    x_Vh_V = torch.matmul(x_Vh, Vh_modified)  # rebuilt from the first row only (second row was zeroed)
+    pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)  # back to the input shape, in pred_cond's dtype
+
+    pred = pred_cond if use_original_formulation else pred_uncond  # base prediction the shift is added to
+    shift = pred_cond - pred_uncond  # uses the projected pred_uncond computed above
+    pred = pred + guidance_scale * shift
+
+    return pred
diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py
index 45f42a1f0f..c50d2b7471 100644
--- a/src/diffusers/hooks/layer_skip.py
+++ b/src/diffusers/hooks/layer_skip.py
@@ -92,7 +92,7 @@ class
AttentionProcessorSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: - if math.isclose(self.dropout, 1.0): + if not math.isclose(self.dropout, 1.0): raise ValueError( "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." )