diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index e4a91ff5c8..2a29b7fa2d 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -1,11 +1,10 @@
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

 import numpy as np
 import torch
@@ -234,25 +233,17 @@ dataset_name_mapping = {
 }


+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]

         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0

     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ class EMAModel:
         return 1 - min(self.decay, value)

     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
+                s_param.copy_(param)

-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
-
-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+

 def main():
     args = parse_args()
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def main():
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def main():
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def main():
     # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
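
For reference, a minimal usage sketch of the refactored, parameter-based EMAModel above: it now tracks a list of shadow tensors instead of a deep copy of the model, so it is built from model.parameters(), stepped after each optimizer update, and copied back into the model before export. This is not part of the patch; TinyNet, the optimizer, and the random data are illustrative stand-ins, and the sketch assumes the EMAModel class from this diff is in scope.

# Usage sketch only -- not part of the patch. Assumes the EMAModel class defined
# in the diff above is importable/in scope; TinyNet and the data are made up.
import torch


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)


model = TinyNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters(), decay=0.9999)  # shadow copies of the initial weights

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # blend the live weights into the shadow params

# Before saving/exporting, overwrite the live weights with the averaged ones.
ema.copy_to(model.parameters())

This mirrors the call pattern the patch installs in main(): EMAModel(unet.parameters()) at setup, ema_unet.step(unet.parameters()) once accelerator.sync_gradients reports an optimization step, and ema_unet.copy_to(unet.parameters()) before building the StableDiffusionPipeline.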