Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[train_text2image] Fix EMA and make it compatible with deepspeed. (#813)
* fix ema
* style
* add comment about copy
* style
* quality
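In short: the EMA helper now tracks a flat list of parameter tensors ("shadow params") instead of keeping a deep copy of the whole model; per the commit title, this is what lets it play nicely with DeepSpeed-wrapped training. A condensed, informal view of the call-site changes in the diff below (old -> new):

    EMAModel(unet)                   ->  EMAModel(unet.parameters())
    ema_unet.step(unet)              ->  ema_unet.step(unet.parameters())
    save ema_unet.averaged_model     ->  ema_unet.copy_to(unet.parameters()), then save unet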
@@ -1,11 +1,10 @@
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional

 import numpy as np
 import torch
@@ -234,25 +233,17 @@ dataset_name_mapping = {
 }


 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """

-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]

         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
         self.optimization_step = 0

     def get_decay(self, optimization_step):
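As a quick sanity check on the new constructor (a minimal sketch with a toy layer standing in for the UNet; EMAModel is the class from the hunk above): the shadow parameters start out as detached, value-equal clones of the live weights, so they share the model's device and dtype but never receive gradients.

    import torch

    layer = torch.nn.Linear(4, 4)         # toy stand-in for the UNet
    ema = EMAModel(layer.parameters())     # EMAModel as defined above

    # Detached clones: same values as the weights, but outside autograd.
    assert all(not s.requires_grad for s in ema.shadow_params)
    assert all(torch.equal(s, p) for s, p in zip(ema.shadow_params, layer.parameters()))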
@@ -263,34 +254,47 @@ class EMAModel:
         return 1 - min(self.decay, value)

     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)

         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)

-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
+                s_param.copy_(param)

-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()

+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+

 def main():
     args = parse_args()
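A self-contained usage sketch of the rewritten class (toy model and toy loop; only EMAModel and its defaults come from the diff, everything else is illustrative): step() is called once per optimizer update and folds the live weights into the shadow copies, and copy_to() writes the averaged weights back into the parameters when it is time to evaluate or save.

    import torch

    model = torch.nn.Linear(4, 4)                      # toy stand-in for the UNet
    ema = EMAModel(model.parameters(), decay=0.9999)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.step(model.parameters())   # update the shadow params after each optimizer step

    # Before evaluating or saving, load the averaged weights into the live parameters.
    ema.copy_to(model.parameters())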
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def main():
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def main():
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
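Putting the later hunks together, the order of operations the script now follows with accelerate looks roughly like the skeleton below. This is a condensed sketch, not the script itself: the toy model, dataloader and loop are illustrative, EMAModel is the class from the diff, and the accelerate calls (prepare, backward, sync_gradients, unwrap_model) are the library's standard API. The key point is that the EMA is created from plain parameter tensors after prepare(), stepped with .parameters() on each real optimization step, and copied into the unwrapped model before saving, so no deep copy of the (possibly DeepSpeed-wrapped) module is ever taken.

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)                      # toy stand-in for the UNet
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(torch.randn(32, 4), batch_size=8)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # Created after prepare(), from parameters only (mirrors the @@ -510 hunk).
    ema = EMAModel(model.parameters())

    for batch in dataloader:
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        if accelerator.sync_gradients:                 # mirrors the @@ -583 hunk
            ema.step(model.parameters())

    if accelerator.is_main_process:                    # mirrors the @@ -598 hunk
        unwrapped = accelerator.unwrap_model(model)
        ema.copy_to(unwrapped.parameters())            # averaged weights go into the saved model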