Mirror of https://github.com/huggingface/diffusers.git

cfg plus plus

This commit is contained in:
Aryan
2025-04-16 08:08:00 +02:00
parent 2dc673a213
commit 77d8a285bf
7 changed files with 163 additions and 5 deletions

src/diffusers/__init__.py

@@ -134,6 +134,7 @@ else:
         [
             "AdaptiveProjectedGuidance",
             "AutoGuidance",
+            "CFGPlusPlusGuidance",
             "ClassifierFreeGuidance",
             "ClassifierFreeZeroStarGuidance",
             "SkipLayerGuidance",
@@ -729,6 +730,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .guiders import (
         AdaptiveProjectedGuidance,
         AutoGuidance,
+        CFGPlusPlusGuidance,
         ClassifierFreeGuidance,
         ClassifierFreeZeroStarGuidance,
         SkipLayerGuidance,

src/diffusers/guiders/__init__.py

@@ -20,6 +20,7 @@ from ..utils import is_torch_available
 if is_torch_available():
     from .adaptive_projected_guidance import AdaptiveProjectedGuidance
     from .auto_guidance import AutoGuidance
+    from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance
     from .classifier_free_guidance import ClassifierFreeGuidance
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
     from .skip_layer_guidance import SkipLayerGuidance

src/diffusers/guiders/classifier_free_guidance_plus_plus.py

@@ -0,0 +1,117 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union, Tuple, List
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
+
+
+class CFGPlusPlusGuidance(BaseGuidance):
+    """
+    CFG++: https://huggingface.co/papers/2406.08070
+
+    Args:
+        guidance_scale (`float`, defaults to `0.7`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 0.7,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+    def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
+        return _default_prepare_inputs(denoiser, self.num_conditions, *args)
+
+    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
+        self._num_outputs_prepared += 1
+        if self._num_outputs_prepared > self.num_conditions:
+            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
+        key = self._input_predictions[self._num_outputs_prepared - 1]
+        self._preds[key] = pred
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfgpp_enabled():
+            pred = pred_cond
+        else:
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred
+
+    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
+        if self._is_cfgpp_enabled():
+            # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
+            pred_cond = self._preds["pred_cond"]
+            pred_uncond = self._preds["pred_uncond"]
+            diff = pred_uncond - pred_cond
+            pred = pred + diff * self.guidance_scale * self._sigma_next
+        return pred
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._num_outputs_prepared == 0
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfgpp_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfgpp_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        return is_within_range
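For orientation, here is a minimal sketch of what `forward` computes, assuming `CFGPlusPlusGuidance` is importable as wired up in this commit; the guidance scale and tensor shapes are illustrative only:

import torch

from diffusers.guiders import CFGPlusPlusGuidance

guider = CFGPlusPlusGuidance(guidance_scale=0.7)

pred_cond = torch.randn(1, 4, 64, 64)    # noise prediction from the text-conditioned branch
pred_uncond = torch.randn(1, 4, 64, 64)  # noise prediction from the unconditional branch

pred = guider.forward(pred_cond, pred_uncond)

# Default (diffusers-native) formulation: start from the unconditional branch.
expected = pred_uncond + 0.7 * (pred_cond - pred_uncond)
assert torch.allclose(pred, expected)

With `use_original_formulation=True`, the base term becomes `pred_cond` instead of `pred_uncond`; the CFG++-specific behavior itself only appears later, in `post_scheduler_step`.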

src/diffusers/guiders/guider_utils.py

@@ -37,6 +37,8 @@ class BaseGuidance:
         self._step: int = None
         self._num_inference_steps: int = None
         self._timestep: torch.LongTensor = None
+        self._sigma: torch.Tensor = None
+        self._sigma_next: torch.Tensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
         self._enabled = True
@@ -61,10 +63,12 @@
     def _force_enable(self):
         self._enabled = True
 
-    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
+    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
         self._timestep = timestep
+        self._sigma = sigma
+        self._sigma_next = sigma_next
         self._preds = {}
         self._num_outputs_prepared = 0
@@ -91,6 +95,9 @@
     def forward(self, *args, **kwargs) -> Any:
         raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
 
+    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
+        return pred
+
     @property
     def is_conditional(self) -> bool:
         raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

src/diffusers/guiders/tangential_classifier_free_guidance.py

@@ -58,7 +58,6 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
         self.guidance_scale = guidance_scale
         self.guidance_rescale = guidance_rescale
         self.use_original_formulation = use_original_formulation
-        self.momentum_buffer = None
 
     def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
         return _default_prepare_inputs(denoiser, self.num_conditions, *args)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

@@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -2301,6 +2301,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -2637,7 +2638,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -2730,6 +2731,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -3053,7 +3055,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
 
                 (
                     latents,
@@ -3148,6 +3150,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
                 data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
+                data.latents = pipeline.guider.post_scheduler_step(data.latents)
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
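Condensed, the per-step ordering these three blocks implement looks like the following sketch. A toy noise predictor stands in for the UNet and text conditioning; everything except the diffusers classes added or used in this commit is illustrative:

import torch

from diffusers import EulerDiscreteScheduler
from diffusers.guiders import CFGPlusPlusGuidance

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=30)
guider = CFGPlusPlusGuidance(guidance_scale=0.7)
denoiser = torch.nn.Identity()  # stand-in for the UNet; only passed through

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for i, t in enumerate(scheduler.timesteps):
    # 1. Hand the guider this step's sigma window (new in this commit).
    guider.set_state(
        step=i, num_inference_steps=30, timestep=t,
        sigma=scheduler.sigmas[i], sigma_next=scheduler.sigmas[i + 1],
    )
    # 2. Run conditional and unconditional branches (toy predictions here),
    #    registering them so post_scheduler_step can find both.
    pred_cond = torch.randn_like(latents)
    pred_uncond = torch.randn_like(latents)
    guider.prepare_outputs(denoiser, pred_cond)
    guider.prepare_outputs(denoiser, pred_uncond)
    noise_pred = guider.forward(pred_cond, pred_uncond)
    # 3. Scheduler step first, then the guider's CFG++ post-step correction.
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    latents = guider.post_scheduler_step(latents)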

src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -669,6 +669,35 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = sample + derivative * dt
 
+        # denoised = sample - model_output * sigmas[i]
+        # d = (sample - denoised) / sigmas[i]
+        # new_sample = denoised + d * sigmas[i + 1]
+        # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i]
+        # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1]
+        # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])
+        # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1)
+
+        # CFG++ =====
+        # denoised = sample - model_output * sigmas[i]
+        # uncond_denoised = sample - model_output_uncond * sigmas[i]
+        # d = (sample - uncond_denoised) / sigmas[i]
+        # new_sample = denoised + d * sigmas[i + 1]
+        # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i]
+        # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2)
+
+        # To go from (1) to (2):
+        # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1]
+        # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1]
+        # new_sample_2 = new_sample_1 + diff * sigmas[i + 1]
+
+        # diff = model_output_uncond - model_output
+        # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond))
+        # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond)
+        # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond
+        # diff = g * (model_output_uncond - model_output_cond)
+
         # Cast sample back to model compatible dtype
         prev_sample = prev_sample.to(model_output.dtype)
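To sanity-check the algebra in the comment block above, a small self-contained snippet; the values for g and the sigmas are arbitrary:

import torch

g = 0.7
sigma, sigma_next = 14.6, 12.3
sample = torch.randn(1, 4, 8, 8)
model_output_cond = torch.randn_like(sample)
model_output_uncond = torch.randn_like(sample)

# Original-formulation CFG combination of the two branches.
model_output = model_output_uncond + g * (model_output_cond - model_output_uncond)

# (1) plain Euler step with the combined output.
new_sample_1 = sample - model_output * sigma + model_output * sigma_next
# (2) CFG++ Euler step: denoise with the combined output, re-noise with the unconditional one.
new_sample_2 = sample - model_output * sigma + model_output_uncond * sigma_next

# post_scheduler_step applies diff = g * (uncond - cond), scaled by sigma_next,
# and recovers (2) from (1).
corrected = new_sample_1 + g * (model_output_uncond - model_output_cond) * sigma_next
assert torch.allclose(corrected, new_sample_2, atol=1e-5)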