diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index 846dd3eda4..fe45f2a5fa 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
-    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPM(
+                pipeline = DDPMPipeline(
                     unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
                 )
 
@@ -172,6 +172,9 @@ if __name__ == "__main__":
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=3 / 4)
+    parser.add_argument("--ema_max_decay", type=float, default=0.999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 5dea0b22b3..d908850dfe 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample
 
     def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
-        if timesteps.dim() != 1:
-            raise ValueError("`timesteps` must be a 1D tensor")
-
-        device = original_samples.device
-        batch_size = original_samples.shape[0]
-        timesteps = timesteps.reshape(batch_size, 1, 1, 1)
-
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
+        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
     def __len__(self):
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index a6f317852d..4cfbc5e59d 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -14,6 +14,8 @@
 import numpy as np
 import torch
 
+from typing import Union
+
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 
@@ -50,3 +52,29 @@ class SchedulerMixin:
             return torch.log(tensor)
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def match_shape(
+        self,
+        values: Union[np.ndarray, torch.Tensor],
+        broadcast_array: Union[np.ndarray, torch.Tensor]
+    ):
+        """
+        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+        Args:
+            values: an array or tensor of values to extract.
+            broadcast_array: an array with a larger shape of K dimensions with the batch
+                dimension equal to the length of values.
+        Returns:
+            an array or tensor of shape [batch_size, 1, ...] where the shape has K dims.
+        """
+
+        tensor_format = getattr(self, "tensor_format", "pt")
+        values = values.flatten()
+
+        while len(values.shape) < len(broadcast_array.shape):
+            values = values[..., None]
+        if tensor_format == "pt":
+            values = values.to(broadcast_array.device)
+
+        return values