diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 39376f3699..e26fcbdd4c 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel): self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) - def forward(self, sample, step_value, transformer_out=None): - timesteps = step_value + def forward(self, sample, timestep, transformer_out=None): + timesteps = timestep x = sample hs = [] emb = self.time_embed( @@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel): resblock_updown=resblock_updown, ) - def forward(self, sample, step_value, low_res=None): - timesteps = step_value + def forward(self, sample, timestep, low_res=None): + timesteps = timestep x = sample _, _, new_height, new_width = x.shape upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 339a9f6550..198af814a7 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin): self.all_modules = nn.ModuleList(modules) - def forward(self, sample, step_value, sigmas=None): - timesteps = step_value + def forward(self, sample, timestep, sigmas=None): + timesteps = timestep x = sample # timestep/noise_level embedding; only for continuous training modules = self.all_modules diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py index 5495ccd6ea..ad738a2737 100644 --- a/src/diffusers/models/unet_unconditional.py +++ b/src/diffusers/models/unet_unconditional.py @@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): # ====================================== def forward( - self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int] + self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int] ) -> Dict[str, torch.FloatTensor]: # TODO(PVP) - to delete later at release # IMPORTANT: NOT RELEVANT WHEN REVIEWING API @@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): self.set_weights() # ====================================== - # 1. time step embeddings - timesteps = step_value + # 1. time step embeddings -> make correct tensor + timesteps = timestep if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) t_emb = get_timestep_embedding( timesteps, diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 2d7e72c77b..d0e920febb 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline class DDIMPipeline(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): + def __init__(self, unet, scheduler): super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): # eta corresponds to η in paper and should be between [0, 1] if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" - num_trained_timesteps = self.noise_scheduler.config.timesteps - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - self.unet.to(torch_device) # Sample gaussian noise to begin loop @@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline): ) image = image.to(torch_device) - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding + # set step values + self.scheduler.set_timesteps(num_inference_steps) - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_image -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_image_direction -> "direction pointingc to x_t" - # - pred_prev_image -> "x_t-1" - for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + for t in tqdm.tqdm(self.scheduler.timesteps): # 1. predict noise residual with torch.no_grad(): - residual = self.unet(image, inference_step_times[t]) + residual = self.unet(image, t) if isinstance(residual, dict): residual = residual["sample"] - # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # do x_t -> x_t-1 + image = self.scheduler.step(residual, t, image, eta)["prev_sample"] - # 3. optionally sample variance - variance = 0 - if eta > 0: - noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise - - # 4. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance - - return image + return {"sample": image} diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 2d18e43c9c..d0949b27bb 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline class DDPMPipeline(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): + def __init__(self, unet, scheduler): super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None): if torch_device is None: @@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline): ) image = image.to(torch_device) - num_prediction_steps = len(self.noise_scheduler) + num_prediction_steps = len(self.scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): # 1. predict noise residual with torch.no_grad(): @@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline): residual = residual["sample"] # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.step(residual, image, t) + pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"] # 3. optionally sample variance variance = 0 if t > 0: noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = self.noise_scheduler.get_variance(t).sqrt() * noise + variance = self.scheduler.get_variance(t).sqrt() * noise # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance diff --git a/src/diffusers/pipelines/glide/pipeline_glide.py b/src/diffusers/pipelines/glide/pipeline_glide.py index a9b6746f5c..c0ccb80477 100644 --- a/src/diffusers/pipelines/glide/pipeline_glide.py +++ b/src/diffusers/pipelines/glide/pipeline_glide.py @@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline): def __init__( self, text_unet: GlideTextToImageUNetModel, - text_noise_scheduler: DDPMScheduler, + text_scheduler: DDPMScheduler, text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer, upscale_unet: GlideSuperResUNetModel, - upscale_noise_scheduler: DDIMScheduler, + upscale_scheduler: DDIMScheduler, ): super().__init__() self.register_modules( text_unet=text_unet, - text_noise_scheduler=text_noise_scheduler, + text_scheduler=text_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, upscale_unet=upscale_unet, - upscale_noise_scheduler=upscale_noise_scheduler, + upscale_scheduler=upscale_scheduler, ) @torch.no_grad() @@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline): transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state # 3. Run the text2image generation step - num_prediction_steps = len(self.text_noise_scheduler) + num_prediction_steps = len(self.text_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): with torch.no_grad(): time_input = torch.tensor([t] * image.shape[0], device=torch_device) model_output = text_model_fn(image, time_input, transformer_out) noise_residual, model_var_values = torch.split(model_output, 3, dim=1) - min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log") - max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log") + min_log = self.text_scheduler.get_variance(t, "fixed_small_log") + max_log = self.text_scheduler.get_variance(t, "fixed_large_log") # The model_var_values is [-1, 1] for [min_var, max_var]. frac = (model_var_values + 1) / 2 model_log_variance = frac * max_log + (1 - frac) * min_log - pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t) + pred_prev_image = self.text_scheduler.step(noise_residual, image, t) noise = torch.randn(image.shape, generator=generator).to(torch_device) variance = torch.exp(0.5 * model_log_variance) * noise @@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline): ).to(torch_device) image = image * upsample_temp - num_trained_timesteps = self.upscale_noise_scheduler.timesteps + num_trained_timesteps = self.upscale_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): @@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline): noise_residual, pred_variance = torch.split(model_output, 3, dim=1) # 2. predict previous mean of image x_t-1 - pred_prev_image = self.upscale_noise_scheduler.step( + pred_prev_image = self.upscale_scheduler.step( noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True ) @@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline): variance = 0 if eta > 0: noise = torch.randn(image.shape, generator=generator).to(torch_device) - variance = ( - self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise - ) + variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 6a4eb1621f..6711d2aba3 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel): class LatentDiffusionPipeline(DiffusionPipeline): - def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): + def __init__(self, vqvae, bert, tokenizer, unet, scheduler): super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler) + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( @@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_embedding = self.bert(text_input.input_ids) - num_trained_timesteps = self.noise_scheduler.config.timesteps + num_trained_timesteps = self.scheduler.config.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) image = torch.randn( @@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline): pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) + pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) # 3. optionally sample variance variance = 0 if eta > 0: noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise + variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 30aafc1e8e..d1918482e1 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline class LatentDiffusionUncondPipeline(DiffusionPipeline): - def __init__(self, vqvae, unet, noise_scheduler): + def __init__(self, vqvae, unet, scheduler): super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler) + scheduler = scheduler.set_format("pt") + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( @@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): self.unet.to(torch_device) self.vqvae.to(torch_device) - num_trained_timesteps = self.noise_scheduler.config.timesteps - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - image = torch.randn( (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), generator=generator, ).to(torch_device) - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding + self.scheduler.set_timesteps(num_inference_steps) - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_image -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_image_direction -> "direction pointingc to x_t" - # - pred_prev_image -> "x_t-1" - for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - # 1. predict noise residual - timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device) - pred_noise_t = self.unet(image, timesteps) + for t in tqdm.tqdm(self.scheduler.timesteps): + residual = self.unet(image, t) - if isinstance(pred_noise_t, dict): - pred_noise_t = pred_noise_t["sample"] + if isinstance(residual, dict): + residual = residual["sample"] - # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) - - # 3. optionally sample variance - variance = 0 - if eta > 0: - noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise - - # 4. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # do x_t -> x_t-1 + image = self.scheduler.step(residual, t, image, eta)["prev_sample"] # decode image with vae image = self.vqvae.decode(image) - return image + return {"sample": image} diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 7dbff9b1b9..9dc7e92ede 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline class PNDMPipeline(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): + def __init__(self, unet, scheduler): super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + scheduler = scheduler.set_format("pt") + self.register_modules(unet=unet, scheduler=scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): # For more information on the sampling method you can take a look at Algorithm 2 of @@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline): ) image = image.to(torch_device) - prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps) + prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps) for t in tqdm.tqdm(range(len(prk_time_steps))): t_orig = prk_time_steps[t] residual = self.unet(image, t_orig) @@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline): if isinstance(residual, dict): residual = residual["sample"] - image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) + image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"] - timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) + timesteps = self.scheduler.get_time_steps(num_inference_steps) for t in tqdm.tqdm(range(len(timesteps))): t_orig = timesteps[t] residual = self.unet(image, t_orig) @@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline): if isinstance(residual, dict): residual = residual["sample"] - image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps) + image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"] return image diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 97b5880c0b..89c7a52914 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,8 +16,10 @@ # and https://github.com/hojonathanho/diffusion import math +from typing import Union import numpy as np +import torch from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.one = np.array(1.0) + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy() + + self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def get_variance(self, t, num_inference_steps): - orig_t = self.config.timesteps // num_inference_steps * t - orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1 - - alpha_prod_t = self.alphas_cumprod[orig_t] - alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False): + def set_timesteps(self, num_inference_steps): + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[ + ::-1 + ].copy() + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + residual: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + eta, + use_clipped_residual=False, + generator=None, + ): # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # - pred_sample_direction -> "direction pointingc to x_t" # - pred_prev_sample -> "x_t-1" - # 1. get actual t and t-1 - orig_t = self.config.timesteps // num_inference_steps * t - orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1 + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.timesteps // self.num_inference_steps # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[orig_t] - alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called @@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self.get_variance(t, num_inference_steps) + variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_residual: @@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - return pred_prev_sample + if eta > 0: + device = residual.device if torch.is_tensor(residual) else "cpu" + noise = torch.randn(residual.shape, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + if not torch.is_tensor(residual): + variance = variance.numpy() + + prev_sample = prev_sample + variance + + return {"prev_sample": prev_sample} def add_noise(self, original_samples, noise, timesteps): sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index f80476b9c5..034a0d4280 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,8 +15,10 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +from typing import Union import numpy as np +import torch from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return variance - def step(self, residual, sample, t, predict_epsilon=True): + def step( + self, + residual: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + predict_epsilon=True, + ): + t = timestep # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one @@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - return pred_prev_sample + return {"prev_sample": pred_prev_sample} def add_noise(self, original_samples, noise, timesteps): sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 8533ad6cd7..3e398d53b1 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,8 +15,10 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +from typing import Union import numpy as np +import torch from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): raise ValueError(f"mode {self.mode} does not exist.") - def step_prk(self, residual, sample, t, num_inference_steps): + def step_prk( + self, + residual: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + num_inference_steps, + ): + t = timestep prk_time_steps = self.get_prk_time_steps(num_inference_steps) t_orig = prk_time_steps[t // 4 * 4] @@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # cur_sample should not be `None` cur_sample = self.cur_sample if self.cur_sample is not None else sample - return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual) + return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)} - def step_plms(self, residual, sample, t, num_inference_steps): + def step_plms( + self, + residual: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + num_inference_steps, + ): + t = timestep if len(self.ets) < 3: raise ValueError( f"{self.__class__} can only be run AFTER scheduler has been run " @@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - return self.get_prev_sample(sample, t_orig, t_orig_prev, residual) + return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)} def get_prev_sample(self, sample, t_orig, t_orig_prev, residual): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index a36946296e..1f694b458d 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -165,7 +165,7 @@ class ModelTesterMixin: # signature.parameters is an OrderedDict => so arg_names order is deterministic arg_names = [*signature.parameters.keys()] - expected_arg_names = ["sample", "step_value"] + expected_arg_names = ["sample", "timestep"] self.assertListEqual(arg_names[:2], expected_arg_names) def test_model_from_config(self): @@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "step_value": time_step} + return {"sample": noise, "timestep": time_step} @property def input_shape(self): @@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - return {"sample": noise, "step_value": time_step, "low_res": low_res} + return {"sample": noise, "timestep": time_step, "low_res": low_res} @property def input_shape(self): @@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) - return {"sample": noise, "step_value": time_step, "transformer_out": emb} + return {"sample": noise, "timestep": time_step, "transformer_out": emb} @property def input_shape(self): @@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) time_step = torch.tensor([10]).to(torch_device) - return {"sample": noise, "step_value": time_step} + return {"sample": noise, "timestep": time_step} @property def input_shape(self): @@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) time_step = torch.tensor(batch_size * [10]).to(torch_device) - return {"sample": noise, "step_value": time_step} + return {"sample": noise, "timestep": time_step} @property def input_shape(self): @@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase): ddpm = DDPMPipeline.from_pretrained(model_path) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) - ddpm.noise_scheduler.num_timesteps = 10 - ddpm_from_hub.noise_scheduler.num_timesteps = 10 + ddpm.scheduler.num_timesteps = 10 + ddpm_from_hub.scheduler.num_timesteps = 10 generator = torch.manual_seed(0) @@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase): model_id = "fusing/ddpm-cifar10" unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) - noise_scheduler = DDPMScheduler.from_config(model_id) - noise_scheduler = noise_scheduler.set_format("pt") + scheduler = DDPMScheduler.from_config(model_id) + scheduler = scheduler.set_format("pt") - ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) generator = torch.manual_seed(0) image = ddpm(generator=generator) @@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase): model_id = "fusing/ddpm-lsun-bedroom-ema" unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) - noise_scheduler = DDIMScheduler.from_config(model_id) - noise_scheduler = noise_scheduler.set_format("pt") + scheduler = DDIMScheduler.from_config(model_id) - ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) generator = torch.manual_seed(0) - image = ddpm(generator=generator) + image = ddpm(generator=generator)["sample"] image_slice = image[0, -1, -3:, -3:].cpu() @@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase): model_id = "fusing/ddpm-cifar10" unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) - noise_scheduler = DDIMScheduler(tensor_format="pt") + scheduler = DDIMScheduler(tensor_format="pt") - ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) + ddim = DDIMPipeline(unet=unet, scheduler=scheduler) generator = torch.manual_seed(0) - image = ddim(generator=generator, eta=0.0) + image = ddim(generator=generator, eta=0.0)["sample"] image_slice = image[0, -1, -3:, -3:].cpu() @@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase): model_id = "fusing/ddpm-cifar10" unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) - noise_scheduler = PNDMScheduler(tensor_format="pt") + scheduler = PNDMScheduler(tensor_format="pt") - pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) + pndm = PNDMPipeline(unet=unet, scheduler=scheduler) generator = torch.manual_seed(0) image = pndm(generator=generator) @@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase): ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True) generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=5) + image = ldm(generator=generator, num_inference_steps=5)["sample"] image_slice = image[0, -1, -3:, -3:].cpu() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 829383daba..0603fa5ddd 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase): def check_over_configs(self, time_step=0, **config): kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: scheduler_class = self.scheduler_classes[0] sample = self.dummy_sample @@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step(residual, sample, time_step, **kwargs) - new_output = new_scheduler.step(residual, sample, time_step, **kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + new_scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase): kwargs = dict(self.forward_default_kwargs) kwargs.update(forward_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: sample = self.dummy_sample residual = 0.1 * sample @@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step(residual, sample, time_step, **kwargs) - new_output = new_scheduler.step(residual, sample, time_step, **kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + new_scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + torch.manual_seed(0) + output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + torch.manual_seed(0) + new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" def test_from_pretrained_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: sample = self.dummy_sample residual = 0.1 * sample @@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step(residual, sample, 1, **kwargs) - new_output = new_scheduler.step(residual, sample, 1, **kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + new_scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase): sample = self.dummy_sample residual = 0.1 * sample - output_0 = scheduler.step(residual, sample, 0, **kwargs) - output_1 = scheduler.step(residual, sample, 1, **kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"] + output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase): def test_pytorch_equal_numpy(self): kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: sample = self.dummy_sample residual = 0.1 * sample @@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase): scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - output = scheduler.step(residual, sample, 1, **kwargs) - output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + scheduler_pt.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] + output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" @@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): residual = model(sample, t) # 2. predict previous mean of sample x_t-1 - pred_prev_sample = scheduler.step(residual, sample, t) + pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"] if t > 0: noise = self.dummy_sample_deter @@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): class DDIMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDIMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0)) + forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50)) def get_scheduler_config(self, **kwargs): config = { @@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): return config def test_timesteps(self): - for timesteps in [1, 5, 100, 1000]: + for timesteps in [100, 500, 1000]: self.check_over_configs(timesteps=timesteps) def test_betas(self): @@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + self.check_over_forward(num_inference_steps=num_inference_steps) def test_eta(self): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): @@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5 - assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5 - assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5 - assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5 - assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5 - assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5 + assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - num_inference_steps, eta = 10, 0.1 - num_trained_timesteps = len(scheduler) - - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) + num_inference_steps, eta = 10, 0.0 model = self.dummy_model() sample = self.dummy_sample_deter - for t in reversed(range(num_inference_steps)): - residual = model(sample, inference_step_times[t]) + scheduler.set_timesteps(num_inference_steps) + for t in scheduler.timesteps: + residual = model(sample, t) - pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta) - - variance = 0 - if eta > 0: - noise = self.dummy_sample_deter - variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise - - sample = pred_prev_sample + variance + sample = scheduler.step(residual, t, sample, eta)["prev_sample"] result_sum = np.sum(np.abs(sample)) result_mean = np.mean(np.abs(sample)) - assert abs(result_sum.item() - 270.6214) < 1e-2 - assert abs(result_mean.item() - 0.3524) < 1e-3 + assert abs(result_sum.item() - 172.0067) < 1e-2 + assert abs(result_mean.item() - 0.223967) < 1e-3 class PNDMSchedulerTest(SchedulerCommonTest): @@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): new_scheduler.ets = dummy_past_residuals[:] new_scheduler.set_plms_mode() - output = scheduler.step(residual, sample, time_step, **kwargs) - new_output = new_scheduler.step(residual, sample, time_step, **kwargs) + output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): new_scheduler.ets = dummy_past_residuals[:] new_scheduler.set_plms_mode() - output = scheduler.step(residual, sample, time_step, **kwargs) - new_output = new_scheduler.step(residual, sample, time_step, **kwargs) + output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler.set_plms_mode() - scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50) + scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"] def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] @@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): t_orig = prk_time_steps[t] residual = model(sample, t_orig) - sample = scheduler.step_prk(residual, sample, t, num_inference_steps) + sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"] timesteps = scheduler.get_time_steps(num_inference_steps) for t in range(len(timesteps)): t_orig = timesteps[t] residual = model(sample, t_orig) - sample = scheduler.step_plms(residual, sample, t, num_inference_steps) + sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"] result_sum = np.sum(np.abs(sample)) result_mean = np.mean(np.abs(sample))