
support cfgpp in ddim

Author: Aryan
Date: 2025-04-16 13:23:24 +02:00
Parent: 77d8a285bf
Commit: 78fca12803
5 changed files with 38 additions and 34 deletions
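
For context: classifier-free guidance combines the conditional and unconditional noise predictions as eps_cfg = eps_uncond + w * (eps_cond - eps_uncond), then uses eps_cfg both to predict the clean sample and to re-noise it. CFG++ keeps eps_cfg only for the clean-sample estimate and re-noises with the unconditional prediction instead. A minimal sketch in sigma parameterization (the names here are illustrative, not from the codebase):

def cfgpp_step(sample, eps_cond, eps_uncond, w, sigma, sigma_next):
    """Illustrative CFG++ update; works on floats or torch tensors."""
    eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)  # standard CFG combination
    denoised = sample - sigma * eps_cfg                 # x0 estimate from the guided noise
    return denoised + sigma_next * eps_uncond           # CFG++: re-noise with the uncond prediction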


@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional, Union, Tuple, List
+from typing import Dict, Optional, Union, Tuple, List

 import torch
@@ -84,15 +83,6 @@ class CFGPlusPlusGuidance(BaseGuidance):
         return pred

-    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
-        if self._is_cfgpp_enabled():
-            # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
-            pred_cond = self._preds["pred_cond"]
-            pred_uncond = self._preds["pred_uncond"]
-            diff = pred_uncond - pred_cond
-            pred = pred + diff * self.guidance_scale * self._sigma_next
-        return pred
-
     @property
     def is_conditional(self) -> bool:
         return self._num_outputs_prepared == 0
@@ -104,6 +94,14 @@ class CFGPlusPlusGuidance(BaseGuidance):
             num_conditions += 1
         return num_conditions

+    @property
+    def outputs(self) -> Dict[str, torch.Tensor]:
+        scheduler_step_kwargs = {}
+        if self._is_cfgpp_enabled():
+            scheduler_step_kwargs["_use_cfgpp"] = True
+            scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond")
+        return self._preds, scheduler_step_kwargs
+
     def _is_cfgpp_enabled(self) -> bool:
         if not self._enabled:
             return False
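
For illustration, when CFG++ is enabled the new outputs property yields a pair shaped roughly like this (a sketch mirroring the hunk above; guider stands for a configured CFGPlusPlusGuidance):

preds, scheduler_step_kwargs = guider.outputs
# preds                 -> {"pred_cond": Tensor, "pred_uncond": Tensor}
# scheduler_step_kwargs -> {"_use_cfgpp": True,
#                           "_model_output_uncond": preds["pred_uncond"]}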


@@ -37,8 +37,6 @@ class BaseGuidance:
         self._step: int = None
         self._num_inference_steps: int = None
         self._timestep: torch.LongTensor = None
-        self._sigma: torch.Tensor = None
-        self._sigma_next: torch.Tensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
         self._enabled = True
@@ -63,12 +61,10 @@ class BaseGuidance:
     def _force_enable(self):
         self._enabled = True

-    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None:
+    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
         self._timestep = timestep
-        self._sigma = sigma
-        self._sigma_next = sigma_next
         self._preds = {}
         self._num_outputs_prepared = 0
@@ -95,9 +91,6 @@ class BaseGuidance:
     def forward(self, *args, **kwargs) -> Any:
         raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

-    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
-        return pred
-
     @property
     def is_conditional(self) -> bool:
         raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@@ -112,7 +105,7 @@ class BaseGuidance:
     @property
     def outputs(self) -> Dict[str, torch.Tensor]:
-        return self._preds
+        return self._preds, {}


 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
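
To show the slimmed-down base contract, a minimal hypothetical subclass could look like the sketch below. It assumes only the surface visible in these hunks: forward, is_conditional, and the inherited outputs, which now returns (self._preds, {}) so that plain guidances forward no extra scheduler.step kwargs.

import torch

class PassThroughGuidance(BaseGuidance):  # hypothetical example, not in the codebase
    def forward(self, pred_cond: torch.Tensor) -> torch.Tensor:
        # No mixing: return the conditional prediction unchanged.
        return pred_cond

    @property
    def is_conditional(self) -> bool:
        return True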


@@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                 (
                     latents,
@@ -2295,13 +2295,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)

                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)

                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]

                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -2638,7 +2637,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                 (
                     latents,
@@ -2725,13 +2724,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)

                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)

                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]

                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
@@ -3055,7 +3053,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
                 (
                     latents,
@@ -3144,13 +3142,12 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)

                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)

                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]

                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():
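
All three denoise blocks apply the same per-step pattern; distilled into a short sketch (illustrative names, UNet call elided):

for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
    # ... run the UNet once per prepared condition and collect noise_pred ...
    preds, scheduler_step_kwargs = guider.outputs
    noise_pred = guider(**preds)
    latents = scheduler.step(noise_pred, t, latents, **scheduler_step_kwargs, return_dict=False)[0]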


@@ -349,6 +349,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
+        _model_output_uncond: Optional[torch.Tensor] = None,
+        _use_cfgpp: bool = False,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -386,6 +388,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
+        if _use_cfgpp and self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}."
+            )

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -411,7 +418,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-            pred_epsilon = model_output
+            pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond
         elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
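
Spelled out, the patched epsilon branch is DDIM formula (12) with eta = 0 and the CFG++ substitution for the re-noising direction. A self-contained sketch (illustrative names, not the scheduler's internals):

def ddim_cfgpp_step(sample, eps_cfg, eps_uncond, alpha_prod_t, alpha_prod_t_prev):
    """Deterministic DDIM step (eta = 0) with the CFG++ epsilon swap."""
    beta_prod_t = 1 - alpha_prod_t
    # x0 is still predicted from the guided output ...
    pred_x0 = (sample - beta_prod_t ** 0.5 * eps_cfg) / alpha_prod_t ** 0.5
    # ... but the direction pointing back to x_t uses the unconditional output.
    return alpha_prod_t_prev ** 0.5 * pred_x0 + (1 - alpha_prod_t_prev) ** 0.5 * eps_uncond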


@@ -584,6 +584,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         s_noise: float = 1.0,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
+        _model_output_uncond: Optional[torch.Tensor] = None,
+        _use_cfgpp: bool = False,
     ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -627,6 +629,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                 "See `StableDiffusionPipeline` for a usage example."
             )
+        if _use_cfgpp and self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}."
+            )

         if self.step_index is None:
             self._init_step_index(timestep)
@@ -668,7 +675,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         dt = self.sigmas[self.step_index + 1] - sigma_hat

         prev_sample = sample + derivative * dt
+        if _use_cfgpp:
+            prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1]

         # denoised = sample - model_output * sigmas[i]
         # d = (sample - denoised) / sigmas[i]
         # new_sample = denoised + d * sigmas[i + 1]
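
The trailing comments hint at why the one-line correction suffices: a plain Euler step lands on denoised + sigma_next * eps_cfg, and adding (eps_uncond - eps_cfg) * sigma_next turns that into the CFG++ target denoised + sigma_next * eps_uncond. A runnable sanity check of the algebra (illustrative names, assuming epsilon prediction and s_churn = 0):

import torch

torch.manual_seed(0)
sample, eps_cfg, eps_uncond = torch.randn(3, 4)  # three toy (4,) tensors
sigma, sigma_next = 2.0, 1.5

denoised = sample - sigma * eps_cfg                      # x0 estimate
prev = sample + eps_cfg * (sigma_next - sigma)           # plain Euler step
prev_cfgpp = prev + (eps_uncond - eps_cfg) * sigma_next  # the patched correction

assert torch.allclose(prev_cfgpp, denoised + sigma_next * eps_uncond)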