From 05d74ef3e7e1bc09888145dd27a3b82844280189 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 04:21:24 +0200 Subject: [PATCH] cfg zero star --- src/diffusers/__init__.py | 6 +- src/diffusers/guiders/__init__.py | 1 + .../guiders/classifier_free_guidance.py | 13 ++- .../classifier_free_zero_star_guidance.py | 100 ++++++++++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 3 - src/diffusers/utils/dummy_pt_objects.py | 15 +++ 6 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 91c41bdd43..e0d629087e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -130,7 +130,9 @@ except OptionalDependencyNotAvailable: _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"]) + _import_structure["guiders"].extend( + ["ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", @@ -714,7 +716,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .guiders import ClassifierFreeGuidance, SkipLayerGuidance + from .guiders import ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 9724d30756..af6c961e23 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -17,5 +17,6 @@ from ..utils import is_torch_available if is_torch_available(): from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 18f2a2d31b..96ac875db5 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -64,13 +64,18 @@ class ClassifierFreeGuidance(GuidanceMixin): self.use_original_formulation = use_original_formulation def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + if math.isclose(self.guidance_scale, 1.0): - return pred_cond - shift = pred_cond - pred_uncond - pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.guidance_scale * shift + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + return pred @property diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 0000000000..518b108554 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,100 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class ClassifierFreeZeroStarGuidance(GuidanceMixin): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif math.isclose(self.guidance_scale, 1.0): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + return num_conditions + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(cond.dtype) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 50f864331a..120b0d632b 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -175,7 +175,6 @@ class SkipLayerGuidance(GuidanceMixin): if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0): pred = pred_cond - elif math.isclose(self.guidance_scale, 1.0): if skip_start_step < self._step < skip_stop_step: shift = pred_cond - pred_cond_skip @@ -183,12 +182,10 @@ class SkipLayerGuidance(GuidanceMixin): pred = pred + self.skip_layer_guidance_scale * shift else: pred = pred_cond - elif math.isclose(self.skip_layer_guidance_scale, 1.0): shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond pred = pred + self.guidance_scale * shift - else: shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3c0f45461b..9e9e2cdfbb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ class ClassifierFreeGuidance(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class ClassifierFreeZeroStarGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SkipLayerGuidance(metaclass=DummyObject): _backends = ["torch"]