From 78fca12803d69541fc63161b036db96792520fd8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 13:23:24 +0200 Subject: [PATCH] support cfgpp in ddim --- .../classifier_free_guidance_plus_plus.py | 20 ++++++++---------- src/diffusers/guiders/guider_utils.py | 11 ++-------- .../pipeline_stable_diffusion_xl_modular.py | 21 ++++++++----------- src/diffusers/schedulers/scheduling_ddim.py | 9 +++++++- .../schedulers/scheduling_euler_discrete.py | 11 +++++++++- 5 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py index 516dbfa0e0..d1c6f87441 100644 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Union, Tuple, List +from typing import Dict, Optional, Union, Tuple, List import torch @@ -84,15 +83,6 @@ class CFGPlusPlusGuidance(BaseGuidance): return pred - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - if self._is_cfgpp_enabled(): - # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later! - pred_cond = self._preds["pred_cond"] - pred_uncond = self._preds["pred_uncond"] - diff = pred_uncond - pred_cond - pred = pred + diff * self.guidance_scale * self._sigma_next - return pred - @property def is_conditional(self) -> bool: return self._num_outputs_prepared == 0 @@ -104,6 +94,14 @@ class CFGPlusPlusGuidance(BaseGuidance): num_conditions += 1 return num_conditions + @property + def outputs(self) -> Dict[str, torch.Tensor]: + scheduler_step_kwargs = {} + if self._is_cfgpp_enabled(): + scheduler_step_kwargs["_use_cfgpp"] = True + scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond") + return self._preds, scheduler_step_kwargs + def _is_cfgpp_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index f51452ed0c..420a566906 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -37,8 +37,6 @@ class BaseGuidance: self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None - self._sigma: torch.Tensor = None - self._sigma_next: torch.Tensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 self._enabled = True @@ -63,12 +61,10 @@ class BaseGuidance: def _force_enable(self): self._enabled = True - def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None: + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep - self._sigma = sigma - self._sigma_next = sigma_next self._preds = {} self._num_outputs_prepared = 0 @@ -95,9 +91,6 @@ class BaseGuidance: def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - return pred - @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") @@ -112,7 +105,7 @@ class BaseGuidance: @property def outputs(self) -> Dict[str, torch.Tensor]: - return self._preds + return self._preds, {} def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e0ea4545f..37d2bbbe6c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2295,13 +2295,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -2638,7 +2637,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2725,13 +2724,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -3055,7 +3053,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -3144,13 +3142,12 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 13c9b3b4a5..2e74c9bbfc 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -349,6 +349,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -386,6 +388,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." + ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -411,7 +418,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output + pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4adec768b7..4c82ca7e38 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -584,6 +584,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = False, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -627,6 +629,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." + ) if self.step_index is None: self._init_step_index(timestep) @@ -668,7 +675,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - + if _use_cfgpp: + prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1] + # denoised = sample - model_output * sigmas[i] # d = (sample - denoised) / sigmas[i] # new_sample = denoised + d * sigmas[i + 1]