Add support for different model prediction types in DDIMInverseScheduler
Resolve alpha_prod_t_prev index issue for final step of inversion
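The diff below touches both the Pix2Pix Zero pipeline and the scheduler itself. For orientation, a minimal sketch of how the new options surface to users, assuming only the constructor arguments visible in the `__init__` hunk below (illustrative values, not taken from this commit):

```python
from diffusers import DDIMInverseScheduler

# Constructor arguments touched by this commit: `set_alpha_to_zero`
# replaces `set_alpha_to_one`, and `prediction_type` is now honored
# for "epsilon", "sample", and "v_prediction" models.
inverse_scheduler = DDIMInverseScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    prediction_type="v_prediction",  # previously the step math was effectively epsilon-only
    set_alpha_to_zero=True,  # use alpha_prod_t_prev = 0 on the final inversion step
)
```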
@@ -1156,8 +1156,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
         # 7. Denoising loop where we obtain the cross-attention maps.
         num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
-        with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
-            for i, t in enumerate(timesteps[1:-1]):
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

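Previously the loop skipped the first and last entries of `timesteps` (and shortened the progress bar by 2) to avoid the out-of-bounds `alphas_cumprod` lookup at the schedule boundary; with the index fix in the scheduler hunks below, every timestep can be inverted. A rough sketch of what changes in iteration count, assuming the usual evenly spaced ascending inversion schedule (illustrative numbers only):

```python
import torch

num_train_timesteps, num_inference_steps = 1000, 50
step_ratio = num_train_timesteps // num_inference_steps

# Ascending timesteps as used for inversion, e.g. 0, 20, ..., 980.
timesteps = torch.arange(num_inference_steps) * step_ratio

old_iterated = timesteps[1:-1]  # boundary steps were skipped: 48 iterations
new_iterated = timesteps        # full schedule is inverted: 50 iterations
print(len(old_iterated), len(new_iterated))
```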
@@ -122,7 +122,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
-        set_alpha_to_one: bool = True,
+        set_alpha_to_zero: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
     ):
@@ -144,11 +144,12 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

-        # At every step in ddim, we are looking into the previous alphas_cumprod
-        # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this parameter simply to one or
-        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        # At every step in inverted ddim, we are looking into the next alphas_cumprod
+        # For the final step, there is no next alphas_cumprod, and the index is out of bounds
+        # `set_alpha_to_zero` decides whether we set this parameter simply to zero
+        # (in this case, self.step() just normalizes the output by self.config.prediction_type)
+        # or whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
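To see what `set_alpha_to_zero` buys: on the final inversion step `alpha_prod_t_prev` falls back to `final_alpha_cumprod = 0`, so the `step()` update below reduces to `prev_sample = (1 - 0) ** 0.5 * model_output`, i.e. the returned sample is exactly the model prediction normalized according to `prediction_type`. A standalone sanity check of that algebra (a sketch, not the scheduler code itself):

```python
import torch

alpha_prod_t_prev = torch.tensor(0.0)  # final_alpha_cumprod with set_alpha_to_zero=True
pred_original_sample = torch.randn(4)
model_output = torch.randn(4)  # e.g. predicted epsilon

# The step() update: x_{t+1} = sqrt(a_prev) * x0_hat + sqrt(1 - a_prev) * eps_hat
pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

assert torch.allclose(prev_sample, model_output)  # the x0 term vanishes entirely
```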
@@ -157,6 +158,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

+    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -205,23 +207,44 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         variance_noise: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
-        e_t = model_output
-
-        x = sample
+        # 1. get previous step value (=t+1)
         prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

-        a_t = self.alphas_cumprod[timestep - 1]
-        a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep < self.config.num_train_timesteps
+            else self.final_alpha_cumprod
+        )

-        pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
+        beta_prod_t = 1 - alpha_prod_t

-        dir_xt = (1.0 - a_prev).sqrt() * e_t
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )

-        prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
+        # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
+
+        # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

         if not return_dict:
-            return (prev_sample, pred_x0)
-        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
+            return (prev_sample, pred_original_sample)
+        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

     def __len__(self):
         return self.config.num_train_timesteps
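One way to convince yourself the new `v_prediction` branch agrees with the `epsilon` branch: with `sample = sqrt(a) * x0 + sqrt(1 - a) * eps` and `v = sqrt(a) * eps - sqrt(1 - a) * x0`, the two conversions added in step 3 above recover `x0` and `eps` exactly. A quick numeric check on synthetic tensors:

```python
import torch

torch.manual_seed(0)
x0, eps = torch.randn(4), torch.randn(4)
alpha_prod_t = torch.tensor(0.7)
beta_prod_t = 1 - alpha_prod_t

# Forward-process sample and its v-parameterization target
sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps
v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0

# The two lines from the "v_prediction" branch of step():
pred_original_sample = alpha_prod_t**0.5 * sample - beta_prod_t**0.5 * v
converted_model_output = alpha_prod_t**0.5 * v + beta_prod_t**0.5 * sample

assert torch.allclose(pred_original_sample, x0, atol=1e-6)
assert torch.allclose(converted_model_output, eps, atol=1e-6)
```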