diff --git a/README.md b/README.md index ebc678da2a..f60af578cd 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ `diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for the own use cases. It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case. -Both models and scredulers should be load- and saveable from the Hub. +Both models and schedulers should be load- and saveable from the Hub. -Example: +Example for [DDPM](https://arxiv.org/abs/2006.11239): ```python import torch @@ -32,65 +32,91 @@ from diffusers import UNetModel, GaussianDDPMScheduler import PIL import numpy as np -generator = torch.Generator() -generator = generator.manual_seed(6694729458485568) +generator = torch.manual_seed(0) torch_device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Load models -scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") +noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise -image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) +image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) # 3. Denoise -for t in reversed(range(len(scheduler))): - # 1. predict noise residual - with torch.no_grad(): - pred_noise_t = self.unet(image, t) +num_prediction_steps = len(noise_scheduler) +for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): + # predict noise residual + with torch.no_grad(): + residual = self.unet(image, t) - # 2. compute alphas, betas - alpha_prod_t = scheduler.get_alpha_prod(t) - alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev + # predict previous mean of image x_t-1 + pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t) - # 3. compute predicted image from residual - # First: compute predicted original image from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt() + # optionally sample variance + variance = 0 + if t > 0: + noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + variance = noise_scheduler.get_variance(t).sqrt() * noise - # Second: Clip "predicted x_0" - pred_original_image = torch.clamp(pred_original_image, -1, 1) + # set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance - # Third: Compute coefficients for pred_original_image x_0 and current image x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * scheduler.get_beta(t)) / beta_prod_t - current_image_coeff = scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t - # Fourth: Compute predicted previous image µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image - - # 5. For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous image - # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image - if t > 0: - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt() - noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator) - prev_image = pred_prev_image + variance * noise - else: - prev_image = pred_prev_image - - # 6. Set current image to prev_image: x_t -> x_t-1 - image = prev_image - -# process image to PIL +# 5. process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1) image_processed = (image_processed + 1.0) * 127.5 image_processed = image_processed.numpy().astype(np.uint8) image_pil = PIL.Image.fromarray(image_processed[0]) -# save image +# 6. save image +image_pil.save("test.png") +``` + +Example for [DDIM](https://arxiv.org/abs/2010.02502): + +```python +import torch +from diffusers import UNetModel, DDIMScheduler +import PIL +import numpy as np + +generator = torch.manual_seed(0) +torch_device = "cuda" if torch.cuda.is_available() else "cpu" + +# 1. Load models +noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq") +model = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) + +# 2. Sample gaussian noise +image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) + +# 3. Denoise +num_inference_steps = 50 +eta = 0.0 # <- deterministic sampling + +for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + # 1. predict noise residual + with torch.no_grad(): + residual = self.unet(image, inference_step_times[t]) + + # 2. predict previous mean of image x_t-1 + pred_prev_image = noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta) + + # 3. optionally sample variance + variance = 0 + if eta > 0: + noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + variance = noise_scheduler.get_variance(t).sqrt() * eta * noise + + # 4. set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance + +# 5. process image to PIL +image_processed = image.cpu().permute(0, 2, 3, 1) +image_processed = (image_processed + 1.0) * 127.5 +image_processed = image_processed.numpy().astype(np.uint8) +image_pil = PIL.Image.fromarray(image_processed[0]) + +# 6. save image image_pil.save("test.png") ``` diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py index 9bbc7f7ce4..593f32ba6d 100644 --- a/models/vision/ddim/modeling_ddim.py +++ b/models/vision/ddim/modeling_ddim.py @@ -58,7 +58,7 @@ class DDIM(DiffusionPipeline): residual = self.unet(image, inference_step_times[t]) # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t, num_inference_steps, eta) + pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t, num_inference_steps, eta) # 3. optionally sample variance variance = 0 @@ -69,44 +69,4 @@ class DDIM(DiffusionPipeline): # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance - # 2. get actual t and t-1 -# train_step = inference_step_times[t] -# prev_train_step = inference_step_times[t - 1] if t > 0 else -1 -# - # 3. compute alphas, betas -# alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) -# alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step) -# beta_prod_t = 1 - alpha_prod_t -# beta_prod_t_prev = 1 - alpha_prod_t_prev -# - # 4. Compute predicted previous image from predicted noise - # First: compute predicted original image from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf -# pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt() -# - # Second: Clip "predicted x_0" -# pred_original_image = torch.clamp(pred_original_image, -1, 1) -# - # Third: Compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) -# std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() -# std_dev_t = eta * std_dev_t -# - # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf -# pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t -# - # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf -# pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction -# - # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image - # Note: eta = 1.0 essentially corresponds to DDPM -# if eta > 0.0: -# noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) -# prev_image = pred_prev_image + std_dev_t * noise -# else: -# prev_image = pred_prev_image -# - # 6. Set current image to prev_image: x_t -> x_t-1 -# image = prev_image - return image diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 579408855a..986d9a39cd 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -39,20 +39,19 @@ class DDPM(DiffusionPipeline): ) num_prediction_steps = len(self.noise_scheduler) - for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): # 1. predict noise residual with torch.no_grad(): residual = self.unet(image, t) # 2. predict previous mean of image x_t-1 - pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t) + pred_prev_image = self.noise_scheduler.get_prev_image_step(residual, image, t) # 3. optionally sample variance variance = 0 if t > 0: noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) - variance = self.noise_scheduler.get_variance(t) * noise + variance = self.noise_scheduler.get_variance(t).sqrt() * noise # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance diff --git a/src/diffusers/schedulers/ddim.py b/src/diffusers/schedulers/ddim.py index 283ab239a9..25f9030631 100644 --- a/src/diffusers/schedulers/ddim.py +++ b/src/diffusers/schedulers/ddim.py @@ -100,7 +100,7 @@ class DDIMScheduler(nn.Module, ConfigMixin): return variance - def predict_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False): + def get_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False): # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py index 6b2439d9a8..8996665c86 100644 --- a/src/diffusers/schedulers/gaussian_ddpm.py +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -47,6 +47,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin): ) self.num_timesteps = int(timesteps) self.clip_image = clip_predicted_image + self.variance_type = variance_type if beta_schedule == "linear": betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) @@ -97,11 +98,17 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin): # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous image # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t).sqrt() + variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)) + + # hacks - were probs added for training stability + if self.variance_type == "fixed_small": + variance = variance.clamp(min=1e-20) + elif self.variance_type == "fixed_large": + variance = self.get_beta(t) return variance - def predict_prev_image_step(self, residual, image, t, output_pred_x_0=False): + def get_prev_image_step(self, residual, image, t, output_pred_x_0=False): # 1. compute alphas, betas alpha_prod_t = self.get_alpha_prod(t) alpha_prod_t_prev = self.get_alpha_prod(t - 1)