mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00

support cfgpp in ddim

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import math
-from typing import Optional, Union, Tuple, List
+from typing import Dict, Optional, Union, Tuple, List
 
 import torch

@@ -84,15 +83,6 @@ class CFGPlusPlusGuidance(BaseGuidance):
 
         return pred
 
-    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
-        if self._is_cfgpp_enabled():
-            # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later!
-            pred_cond = self._preds["pred_cond"]
-            pred_uncond = self._preds["pred_uncond"]
-            diff = pred_uncond - pred_cond
-            pred = pred + diff * self.guidance_scale * self._sigma_next
-        return pred
-
     @property
     def is_conditional(self) -> bool:
         return self._num_outputs_prepared == 0

@@ -104,6 +94,14 @@ class CFGPlusPlusGuidance(BaseGuidance):
             num_conditions += 1
         return num_conditions
 
+    @property
+    def outputs(self) -> Dict[str, torch.Tensor]:
+        scheduler_step_kwargs = {}
+        if self._is_cfgpp_enabled():
+            scheduler_step_kwargs["_use_cfgpp"] = True
+            scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond")
+        return self._preds, scheduler_step_kwargs
+
     def _is_cfgpp_enabled(self) -> bool:
         if not self._enabled:
             return False

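For context: CFG++ keeps the usual classifier-free-guided combination for estimating the clean sample, but re-noises with the unconditional prediction; that is why the guider now hands `_model_output_uncond` to the scheduler instead of patching the latents after the scheduler step. A minimal sketch of the guided combination, assuming the guider's forward performs a standard CFG mix (`cfgpp_combine` is an illustrative name, not the guider's API):

    import torch

    def cfgpp_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # Standard classifier-free mix; CFG++ typically runs with a small
        # guidance scale (on the order of 0 to 1) rather than the usual 5 to 15.
        return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

The scheduler uses this mixed epsilon only to predict x0, and `_model_output_uncond` to step back to the next noise level (see the scheduler hunks below).
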
@@ -37,8 +37,6 @@ class BaseGuidance:
         self._step: int = None
         self._num_inference_steps: int = None
         self._timestep: torch.LongTensor = None
-        self._sigma: torch.Tensor = None
-        self._sigma_next: torch.Tensor = None
         self._preds: Dict[str, torch.Tensor] = {}
         self._num_outputs_prepared: int = 0
         self._enabled = True

@@ -63,12 +61,10 @@ class BaseGuidance:
     def _force_enable(self):
         self._enabled = True
 
-    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None:
+    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
         self._step = step
         self._num_inference_steps = num_inference_steps
         self._timestep = timestep
-        self._sigma = sigma
-        self._sigma_next = sigma_next
         self._preds = {}
         self._num_outputs_prepared = 0

@@ -95,9 +91,6 @@ class BaseGuidance:
     def forward(self, *args, **kwargs) -> Any:
         raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
 
-    def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor:
-        return pred
-
     @property
     def is_conditional(self) -> bool:
         raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")

@@ -112,7 +105,7 @@ class BaseGuidance:
 
     @property
     def outputs(self) -> Dict[str, torch.Tensor]:
-        return self._preds
+        return self._preds, {}
 
 
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):

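Note the contract change: `outputs` now always returns a `(preds, scheduler_step_kwargs)` pair, so call sites must unpack two values even when the kwargs dict is empty. A minimal caller sketch, with `guider`, `scheduler`, `t`, and `latents` standing in for whatever the pipeline wires up:

    # The second element is forwarded verbatim into scheduler.step;
    # it is {} for guidance modes with no scheduler-side behavior.
    outputs, scheduler_step_kwargs = guider.outputs
    noise_pred = guider(**outputs)
    latents = scheduler.step(noise_pred, t, latents, **scheduler_step_kwargs, return_dict=False)[0]

The pipeline-block hunks below apply exactly this pattern.
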
@@ -2241,7 +2241,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
 
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
 
                 (
                     latents,

@@ -2295,13 +2295,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
 
                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)
 
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():

@@ -2638,7 +2637,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
         # (5) Denoise loop
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
 
                 (
                     latents,

@@ -2725,13 +2724,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
 
                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)
 
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():

@@ -3055,7 +3053,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
 
         with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
             for i, t in enumerate(data.timesteps):
-                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1])
+                pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t)
 
                 (
                     latents,

@@ -3144,13 +3142,12 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred)
 
                 # Perform guidance
-                outputs = pipeline.guider.outputs
+                outputs, scheduler_step_kwargs = pipeline.guider.outputs
                 data.noise_pred = pipeline.guider(**outputs)
 
                 # Perform scheduler step using the predicted output
                 data.latents_dtype = data.latents.dtype
-                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
-                data.latents = pipeline.guider.post_scheduler_step(data.latents)
+                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0]
 
                 if data.latents.dtype != data.latents_dtype:
                     if torch.backends.mps.is_available():

@@ -349,6 +349,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
+        _model_output_uncond: Optional[torch.Tensor] = None,
+        _use_cfgpp: bool = False,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion

@@ -386,6 +388,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
+        if _use_cfgpp and self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}."
+            )
+
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

@@ -411,7 +418,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-            pred_epsilon = model_output
+            pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond
         elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

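A worked sketch of what the `pred_epsilon` swap means for the deterministic DDIM update (assuming `eta = 0` and no variance noise; the function and argument names are illustrative, not the scheduler's API):

    import torch

    def ddim_step_cfgpp(model_output, model_output_uncond, sample, alpha_prod_t, alpha_prod_t_prev):
        # "predicted x_0" (formula (12) of the DDIM paper) still uses the guided epsilon
        pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
        # CFG++: the direction back to the next noise level uses the unconditional epsilon
        pred_epsilon = model_output_uncond
        return alpha_prod_t_prev ** 0.5 * pred_original_sample + (1 - alpha_prod_t_prev) ** 0.5 * pred_epsilon
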
@@ -584,6 +584,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         s_noise: float = 1.0,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
+        _model_output_uncond: Optional[torch.Tensor] = None,
+        _use_cfgpp: bool = False,
     ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion

@@ -627,6 +629,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                 "See `StableDiffusionPipeline` for a usage example."
             )
 
+        if _use_cfgpp and self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}."
+            )
+
         if self.step_index is None:
             self._init_step_index(timestep)

@@ -668,7 +675,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         dt = self.sigmas[self.step_index + 1] - sigma_hat
 
         prev_sample = sample + derivative * dt
+        if _use_cfgpp:
+            prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1]
 
         # denoised = sample - model_output * sigmas[i]
         # d = (sample - denoised) / sigmas[i]
         # new_sample = denoised + d * sigmas[i + 1]

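The Euler correction is the same trick in sigma-space: the guided epsilon shapes the denoised estimate, and the unconditional epsilon re-noises it, exactly as the trailing comments suggest. A small self-contained check of that identity, assuming epsilon prediction and no churn (`sigma_hat == sigma`):

    import torch

    sigma, sigma_next = 1.2, 0.8
    sample = torch.randn(4)
    eps_cfg = torch.randn(4)     # guided (post-CFG) model output
    eps_uncond = torch.randn(4)  # unconditional model output

    denoised = sample - eps_cfg * sigma
    euler = sample + (sample - denoised) / sigma * (sigma_next - sigma)
    cfgpp = euler + (eps_uncond - eps_cfg) * sigma_next

    # prev_sample ends up as the guided x0 estimate re-noised with eps_uncond:
    assert torch.allclose(cfgpp, denoised + eps_uncond * sigma_next, atol=1e-6)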