mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
cfg plus plus
This commit is contained in:
@@ -134,6 +134,7 @@ else:
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AutoGuidance",
|
||||
"CFGPlusPlusGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"SkipLayerGuidance",
|
||||
@@ -729,6 +730,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
CFGPlusPlusGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
SkipLayerGuidance,
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..utils import is_torch_available
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
|
||||
117
src/diffusers/guiders/classifier_free_guidance_plus_plus.py
Normal file
117
src/diffusers/guiders/classifier_free_guidance_plus_plus.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Union, Tuple, List
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs
|
||||
|
||||
|
||||
class CFGPlusPlusGuidance(BaseGuidance):
    """
    Guider implementing CFG++ (https://huggingface.co/papers/2406.08070).

    The guided prediction is built in `forward` from the conditional and unconditional predictions, and a
    follow-up correction is applied to the scheduler output in `post_scheduler_step`.

    Args:
        guidance_scale (`float`, defaults to `0.7`):
            Multiplier applied to `pred_cond - pred_uncond` in `forward`, and to the correction term added in
            `post_scheduler_step`.
        guidance_rescale (`float`, defaults to `0.0`):
            When greater than `0.0`, the guided prediction is passed through `rescale_noise_cfg` together with
            `pred_cond`. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, the scaled shift is added to `pred_cond`; otherwise it is added to `pred_uncond`.
        start (`float`, defaults to `0.0`):
            Fraction of the total number of denoising steps at which guidance becomes active.
        stop (`float`, defaults to `1.0`):
            Fraction of the total number of denoising steps at which guidance stops being active.
    """

    # Order in which successive `prepare_outputs` calls are stored into `self._preds`.
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 0.7,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__(start, stop)

        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation

    def prepare_inputs(
        self,
        denoiser: torch.nn.Module,
        *args: Union[Tuple[torch.Tensor], List[torch.Tensor]],
    ) -> Tuple[List[torch.Tensor], ...]:
        """Delegate to `_default_prepare_inputs` using the current number of conditions."""
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        """
        Record one denoiser prediction under the next key of `_input_predictions`.

        Raises:
            ValueError: If called more times than `num_conditions` since the counter was last reset.
        """
        # The counter is bumped before validation, so it stays incremented even when the call is rejected.
        self._num_outputs_prepared += 1
        prepared = self._num_outputs_prepared
        if prepared > self.num_conditions:
            raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
        self._preds[self._input_predictions[prepared - 1]] = pred

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Combine the conditional and unconditional predictions into the guided prediction.

        When CFG++ is inactive the conditional prediction is used as-is. The optional rescale is applied in
        both cases whenever `guidance_rescale > 0.0`.
        """
        guided = pred_cond

        if self._is_cfgpp_enabled():
            anchor = pred_cond if self.use_original_formulation else pred_uncond
            guided = anchor + self.guidance_scale * (pred_cond - pred_uncond)

        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)

        return guided

    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
        """
        Adjust the scheduler output using the predictions stored for the current step.

        Returns `pred` untouched when CFG++ is inactive.
        """
        if not self._is_cfgpp_enabled():
            return pred

        # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
        # NOTE(review): the correction is scaled by `guidance_scale` only; confirm this is still the intended
        # term when `use_original_formulation` or `guidance_rescale` changes what `forward` returned.
        cond = self._preds["pred_cond"]
        uncond = self._preds["pred_uncond"]
        return pred + (uncond - cond) * self.guidance_scale * self._sigma_next

    @property
    def is_conditional(self) -> bool:
        """`True` until the first prediction of the current step has been recorded."""
        return self._num_outputs_prepared == 0

    @property
    def num_conditions(self) -> int:
        """Number of predictions expected per step: 2 while CFG++ is active, otherwise 1."""
        return 2 if self._is_cfgpp_enabled() else 1

    def _is_cfgpp_enabled(self) -> bool:
        """Return whether guidance is switched on and the current step lies in the active window."""
        if not self._enabled:
            return False

        total_steps = self._num_inference_steps
        if total_steps is None:
            # No step count recorded: treat every step as inside the window.
            return True

        first_active = int(self._start * total_steps)
        first_inactive = int(self._stop * total_steps)
        return first_active <= self._step < first_inactive
|
||||
@@ -37,6 +37,8 @@ class BaseGuidance:
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._sigma: torch.Tensor = None
|
||||
self._sigma_next: torch.Tensor = None
|
||||
self._preds: Dict[str, torch.Tensor] = {}
|
||||
self._num_outputs_prepared: int = 0
|
||||
self._enabled = True
|
||||
@@ -61,10 +63,12 @@ class BaseGuidance:
|
||||
def _force_enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._sigma = sigma
|
||||
self._sigma_next = sigma_next
|
||||
self._preds = {}
|
||||
self._num_outputs_prepared = 0
|
||||
|
||||
@@ -91,6 +95,9 @@ class BaseGuidance:
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
||||
|
||||
# Hook applied to the scheduler step output; this base implementation returns `pred` unchanged.
def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
|
||||
return pred
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
||||
|
||||
@@ -58,7 +58,6 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
return _default_prepare_inputs(denoiser, self.num_conditions, *args)
|
||||
|
||||
@@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
|
||||
|
||||
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(data.timesteps):
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
|
||||
|
||||
(
|
||||
latents,
|
||||
@@ -2301,6 +2301,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
|
||||
# Perform scheduler step using the predicted output
|
||||
data.latents_dtype = data.latents.dtype
|
||||
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
|
||||
data.latents = pipeline.guider.post_scheduler_step(data.latents)
|
||||
|
||||
if data.latents.dtype != data.latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -2637,7 +2638,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
# (5) Denoise loop
|
||||
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(data.timesteps):
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
|
||||
|
||||
(
|
||||
latents,
|
||||
@@ -2730,6 +2731,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
|
||||
# Perform scheduler step using the predicted output
|
||||
data.latents_dtype = data.latents.dtype
|
||||
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
|
||||
data.latents = pipeline.guider.post_scheduler_step(data.latents)
|
||||
|
||||
if data.latents.dtype != data.latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -3053,7 +3055,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
|
||||
with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(data.timesteps):
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
|
||||
pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
|
||||
|
||||
(
|
||||
latents,
|
||||
@@ -3148,6 +3150,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
|
||||
# Perform scheduler step using the predicted output
|
||||
data.latents_dtype = data.latents.dtype
|
||||
data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
|
||||
data.latents = pipeline.guider.post_scheduler_step(data.latents)
|
||||
|
||||
if data.latents.dtype != data.latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
|
||||
@@ -669,6 +669,35 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
# denoised = sample - model_output * sigmas[i]
|
||||
# d = (sample - denoised) / sigmas[i]
|
||||
# new_sample = denoised + d * sigmas[i + 1]
|
||||
|
||||
# new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i]
|
||||
# new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1]
|
||||
# new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])
|
||||
# new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1)
|
||||
|
||||
# CFG++ =====
|
||||
# denoised = sample - model_output * sigmas[i]
|
||||
# uncond_denoised = sample - model_output_uncond * sigmas[i]
|
||||
# d = (sample - uncond_denoised) / sigmas[i]
|
||||
# new_sample = denoised + d * sigmas[i + 1]
|
||||
|
||||
# new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i]
|
||||
# new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2)
|
||||
|
||||
# To go from (1) to (2):
|
||||
# new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1]
|
||||
# new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1]
|
||||
# new_sample_2 = new_sample_1 + diff * sigmas[i + 1]
|
||||
|
||||
# diff = model_output_uncond - model_output
|
||||
# diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond))
|
||||
# diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond)
|
||||
# diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond
|
||||
# diff = g * (model_output_uncond - model_output_cond)
|
||||
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user