From 77d8a285bf36960a2e0315725e322e2f0f1f6197 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 16 Apr 2025 08:08:00 +0200
Subject: [PATCH] cfg plus plus

---
 src/diffusers/__init__.py                          |   2 +
 src/diffusers/guiders/__init__.py                  |   1 +
 .../classifier_free_guidance_plus_plus.py          | 117 ++++++++++++++++++
 src/diffusers/guiders/guider_utils.py              |   9 +-
 .../tangential_classifier_free_guidance.py         |   1 -
 .../pipeline_stable_diffusion_xl_modular.py        |   9 +-
 .../schedulers/scheduling_euler_discrete.py        |  29 +++++
 7 files changed, 163 insertions(+), 5 deletions(-)
 create mode 100644 src/diffusers/guiders/classifier_free_guidance_plus_plus.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index a4f55acf8b..424011961a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -134,6 +134,7 @@ else:
         [
             "AdaptiveProjectedGuidance",
             "AutoGuidance",
+            "CFGPlusPlusGuidance",
             "ClassifierFreeGuidance",
             "ClassifierFreeZeroStarGuidance",
             "SkipLayerGuidance",
@@ -729,6 +730,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .guiders import (
         AdaptiveProjectedGuidance,
         AutoGuidance,
+        CFGPlusPlusGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
         SkipLayerGuidance,
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index 3c1ee29338..56e95c92b6 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -20,6 +20,7 @@ from ..utils import is_torch_available
 if is_torch_available():
     from .adaptive_projected_guidance import AdaptiveProjectedGuidance
     from .auto_guidance import AutoGuidance
+    from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
     from .skip_layer_guidance import SkipLayerGuidance
diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py
new file mode 100644
index 0000000000..516dbfa0e0
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py
@@ -0,0 +1,117 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union, Tuple, List
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class CFGPlusPlusGuidance(BaseGuidance):
+    """
+    CFG++: https://huggingface.co/papers/2406.08070
+
+    Args:
+        guidance_scale (`float`, defaults to `0.7`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 0.7,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
+        key = self._input_predictions[self._num_outputs_prepared - 1]
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfgpp_enabled():
+            pred = pred_cond
+        else:
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
+        if self._is_cfgpp_enabled():
+            # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
+            pred_cond = self._preds["pred_cond"]
+            pred_uncond = self._preds["pred_uncond"]
+            diff = pred_uncond - pred_cond
+            pred = pred + diff * self.guidance_scale * self._sigma_next
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfgpp_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfgpp_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        return is_within_range
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
index e7d22a50d0..f51452ed0c 100644
--- a/src/diffusers/guiders/guider_utils.py
+++ b/src/diffusers/guiders/guider_utils.py
@@ -37,6 +37,8 @@ class BaseGuidance:
         self._step: int = None
         self._num_inference_steps: int = None
         self._timestep: torch.LongTensor = None
+        self._sigma: torch.Tensor = None
+        self._sigma_next: torch.Tensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
         self._enabled = True
@@ -61,10 +63,12 @@ class BaseGuidance:
     def _force_enable(self):
         self._enabled = True
 
-    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
+    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
         self._timestep = timestep
+        self._sigma = sigma
+        self._sigma_next = sigma_next
         self._preds = {}
         self._num_outputs_prepared = 0
 
@@ -91,6 +95,9 @@ class BaseGuidance:
     def forward(self, *args, **kwargs) -> Any:
         raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
 
+    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
+        return pred
+
     @property
     def is_conditional(self) -> bool:
         raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
index 078d795baa..7529114bfd 100644
--- a/src/diffusers/guiders/tangential_classifier_free_guidance.py
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -58,7 +58,6 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
         self.guidance_scale = guidance_scale
         self.guidance_rescale = guidance_rescale
         self.use_original_formulation = use_original_formulation
-        self.momentum_buffer = None
 
     def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
         return _default_prepare_inputs(denoiser, self.num_conditions, *args)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index 76030a7153..8e0ea4545f 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
 
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -2301,6 +2301,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -2637,7 +2638,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -2730,6 +2731,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -3053,7 +3055,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
 
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -3148,6 +3150,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 56757f3ca1..4adec768b7 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -669,6 +669,35 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         prev_sample = sample + derivative * dt
 
+        # denoised = sample - model_output * sigmas[i]
+        # d = (sample - denoised) / sigmas[i]
+        # new_sample = denoised + d * sigmas[i + 1]
+
+        # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i]
+        # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1]
+        # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])
+        # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1)
+
+        # CFG++ =====
+        # denoised = sample - model_output * sigmas[i]
+        # uncond_denoised = sample - model_output_uncond * sigmas[i]
+        # d = (sample - uncond_denoised) / sigmas[i]
+        # new_sample = denoised + d * sigmas[i + 1]
+
+        # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i]
+        # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2)
+
+        # To go from (1) to (2):
+        # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1]
+        # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1]
+        # new_sample_2 = new_sample_1 + diff * sigmas[i + 1]
+
+        # diff = model_output_uncond - model_output
+        # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond))
+        # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond)
+        # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond
+        # diff = g * (model_output_uncond - model_output_cond)
+
         # Cast sample back to model compatible dtype
         prev_sample = prev_sample.to(model_output.dtype)
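
As a quick sanity check of the derivation above (not part of the patch itself), the standalone snippet below verifies numerically that applying the post_scheduler_step correction after a plain Euler step reproduces the CFG++ update (2). It assumes epsilon prediction and the diffusers-native combination pred_uncond + guidance_scale * (pred_cond - pred_uncond) used in CFGPlusPlusGuidance.forward; the sigma values and tensor shapes are arbitrary placeholders chosen only for illustration.

import torch

torch.manual_seed(0)
g = 0.7  # CFG++ guidance scale
sigma, sigma_next = 14.6, 11.9  # placeholder values standing in for sigmas[i], sigmas[i + 1]
sample = torch.randn(1, 4, 64, 64)
eps_cond = torch.randn_like(sample)    # pred_cond from the conditional forward pass
eps_uncond = torch.randn_like(sample)  # pred_uncond from the unconditional forward pass

# What CFGPlusPlusGuidance.forward feeds to the scheduler (diffusers-native formulation)
eps_cfg = eps_uncond + g * (eps_cond - eps_uncond)

# Plain Euler step with the guided noise, i.e. equation (1) above
euler_step = sample + eps_cfg * (sigma_next - sigma)

# Correction applied by CFGPlusPlusGuidance.post_scheduler_step:
# pred + (pred_uncond - pred_cond) * guidance_scale * sigma_next
corrected = euler_step + (eps_uncond - eps_cond) * g * sigma_next

# CFG++ update, equation (2): denoise with eps_cfg, renoise with eps_uncond
cfgpp_step = sample - eps_cfg * sigma + eps_uncond * sigma_next

assert torch.allclose(corrected, cfgpp_step, atol=1e-4)
print("post_scheduler_step reproduces the CFG++ Euler update")

This algebraic equivalence is what lets the patch leave EulerDiscreteScheduler.step untouched and apply CFG++ purely as a guider-side correction after the scheduler step.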