From 7c823c2ed72aed5dc0db24f368a3663302e71177 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 14 Dec 2022 02:35:41 -0800
Subject: [PATCH] manually update train_unconditional_ort (#1694)

* manually update train_unconditional_ort

* formatting

Co-authored-by: Prathik Rao
---
 .../train_unconditional_ort.py               | 292 ++++++++++++++----
 1 file changed, 231 insertions(+), 61 deletions(-)

diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py
index 4e97732ade..b5974b84b3 100644
--- a/examples/unconditional_image_generation/train_unconditional_ort.py
+++ b/examples/unconditional_image_generation/train_unconditional_ort.py
@@ -1,4 +1,5 @@
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
@@ -31,9 +32,192 @@ from tqdm.auto import tqdm
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")
+

 logger = get_logger(__name__)


+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    if not isinstance(arr, torch.Tensor):
+        arr = torch.from_numpy(arr)
+    res = arr[timesteps].float().to(timesteps.device)
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that HF Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="ddpm-model-64",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=64,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    parser.add_argument(
+        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="cosine",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use Exponential Moving Average for the final model weights.",
+    )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+            " and an Nvidia Ampere GPU."
+        ),
+    )
+
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        choices=["epsilon", "sample"],
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+    )
+
+    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+    return args
+
+
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
         token = HfFolder.get_token()
@@ -77,7 +261,17 @@ def main(args):
         ),
     )
     model = ORTModule(model)
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_prediction_type:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            prediction_type=args.prediction_type,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -101,7 +295,6 @@ def main(args):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
-            use_auth_token=True if args.use_auth_token else None,
             split="train",
         )
     else:
@@ -111,8 +304,12 @@ def main(args):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}

+    logger.info(f"Dataset size: {len(dataset)}")
+
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
@@ -127,7 +324,12 @@ def main(args):

     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

-    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
+    ema_model = EMAModel(
+        accelerator.unwrap_model(model),
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay,
+    )

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -171,11 +373,26 @@ def main(args):

             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
-                loss = F.mse_loss(noise_pred, noise)
+                model_output = model(noisy_images, timesteps, return_dict=True)[0]
+
+                if args.prediction_type == "epsilon":
+                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
+                elif args.prediction_type == "sample":
+                    alpha_t = _extract_into_tensor(
+                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+                    )
+                    snr_weights = alpha_t / (1 - alpha_t)
+                    loss = snr_weights * F.mse_loss(
+                        model_output, clean_images, reduction="none"
+                    )  # use SNR weighting from distillation paper
+                    loss = loss.mean()
+                else:
+                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
                 accelerator.backward(loss)

-                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
@@ -204,9 +421,13 @@ def main(args):
                     scheduler=noise_scheduler,
                 )

-                generator = torch.manual_seed(0)
+                generator = torch.Generator(device=pipeline.device).manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
+                images = pipeline(
+                    generator=generator,
+                    batch_size=args.eval_batch_size,
+                    output_type="numpy",
+                ).images

                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
@@ -225,56 +446,5 @@ def main(args):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset_name", type=str, default=None)
-    parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
-    parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--train_batch_size", type=int, default=16)
-    parser.add_argument("--eval_batch_size", type=int, default=16)
-    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_images_epochs", type=int, default=10)
-    parser.add_argument("--save_model_epochs", type=int, default=10)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--lr_scheduler", type=str, default="cosine")
-    parser.add_argument("--lr_warmup_steps", type=int, default=500)
-    parser.add_argument("--adam_beta1", type=float, default=0.95)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
-    parser.add_argument("--use_ema", action="store_true", default=True)
-    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
-    parser.add_argument("--push_to_hub", action="store_true")
-    parser.add_argument("--use_auth_token", action="store_true")
-    parser.add_argument("--hub_token", type=str, default=None)
-    parser.add_argument("--hub_model_id", type=str, default=None)
-    parser.add_argument("--hub_private_repo", action="store_true")
-    parser.add_argument("--logging_dir", type=str, default="logs")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
-
+    args = parse_args()
     main(args)