mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[train_text2image] Fix EMA and make it compatible with deepspeed. (#813)

* fix ema

* style

* add comment about copy

* style

* quality
Suraj Patil authored 2022-10-12 19:13:22 +02:00, committed by GitHub
parent 5afc2b60cd
commit 008b608f15
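
Context for the diff below: the old `EMAModel` deep-copied the whole UNet and updated its state dict, which does not hold up once `accelerate`/DeepSpeed wraps the model, while the new one tracks a list of shadow parameter tensors. The following is a minimal, self-contained sketch of that pattern; the `ParamEMA` helper, the `nn.Linear` stand-in model, and the SGD optimizer are illustrative placeholders, not the script's actual `EMAModel` (which additionally warms up the decay via `get_decay` and exposes a `to()` helper).

```python
# Sketch only: a condensed parameter-level EMA in the spirit of the EMAModel in this commit.
# The model, optimizer, and ParamEMA helper here are placeholders, not diffusers APIs.
import torch


class ParamEMA:
    def __init__(self, parameters, decay=0.10):
        # Shadow copies of the parameter tensors themselves, not of the module.
        self.shadow = [p.clone().detach() for p in parameters]
        self.decay = decay  # weight given to the current parameters at each step

    @torch.no_grad()
    def step(self, parameters):
        for s, p in zip(self.shadow, parameters):
            # In-place: s <- s - d * (s - p), i.e. (1 - d) * s + d * p
            s.sub_(self.decay * (s - p))

    @torch.no_grad()
    def copy_to(self, parameters):
        for s, p in zip(self.shadow, parameters):
            p.data.copy_(s.data)


net = torch.nn.Linear(4, 4)                      # stands in for the UNet
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
ema = ParamEMA(net.parameters())

for _ in range(10):
    loss = net(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.step(net.parameters())                   # mirrors: ema_unet.step(unet.parameters())

ema.copy_to(net.parameters())                    # mirrors: ema_unet.copy_to(unet.parameters()) before saving
```

Because only flat parameter tensors are stored, the EMA object never has to copy or unwrap the prepared module, which is what makes it workable under DeepSpeed-style wrapping.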


@@ -1,11 +1,10 @@
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional
 
 import numpy as np
 import torch
@@ -234,25 +233,17 @@ dataset_name_mapping = {
 }
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
 
         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0
 
     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ class EMAModel:
         return 1 - min(self.decay, value)
 
     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)
 
         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)
 
-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
+                s_param.copy_(param)
 
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
-
-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()
 
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+
 
 def main():
     args = parse_args()
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def main():
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def main():
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     if args.use_ema:
-                        ema_unet.step(unet)
+                        ema_unet.step(unet.parameters())
                     progress_bar.update(1)
                     global_step += 1
                     accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
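
As a final aside, not part of the commit: a quick numeric check that the in-place update used in the new `step()`, `s_param.sub_(decay * (s_param - param))`, is just a linear interpolation between the shadow weights and the current weights, with `decay` here being the weight on the current parameters (the script derives it each step via `get_decay`). The concrete numbers below are made up for illustration.

```python
# Sketch: verify that the in-place update from step() is a plain interpolation.
import torch

decay = 0.25                                # stand-in value; the script computes it via get_decay()
s_param = torch.tensor([1.0, 2.0, 3.0])     # shadow (EMA) values
param = torch.tensor([2.0, 0.0, 3.0])       # current model parameters

expected = (1 - decay) * s_param + decay * param
s_param.sub_(decay * (s_param - param))     # the update from the diff

assert torch.allclose(s_param, expected)
print(s_param)  # tensor([1.2500, 1.5000, 3.0000])
```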