
[train_unconditional] fix gradient accumulation. (#308)

Fix gradient accumulation in the unconditional training example: forward args.gradient_accumulation_steps to the Accelerator (previously the flag was parsed but never used, so no accumulation actually happened), size the per-epoch progress bar in optimizer updates rather than micro-batches, and advance the progress bar and global_step only when accelerator.sync_gradients reports that an optimizer step actually ran.
commit 1b1d6444c6
parent 4724250980
Author: Suraj Patil
Date: 2022-09-01 19:32:15 +05:30
Committed by: GitHub


@@ -1,4 +1,5 @@
 import argparse
+import math
 import os
 
 import torch
@@ -29,6 +30,7 @@ logger = get_logger(__name__)
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
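The crux of the fix is this one argument: the script already parsed --gradient_accumulation_steps, but the value was never forwarded, so the Accelerator kept its default of a single step and accumulation silently never happened. A minimal sketch of the difference, assuming a recent accelerate (the value 4 is hypothetical; the script reads it from the CLI flag):

from accelerate import Accelerator

# Before the fix (effectively): default gradient_accumulation_steps=1,
# so accelerator.accumulate(model) synced gradients on every micro-batch.
# accelerator = Accelerator()

# After the fix: gradients are synced, and the prepared optimizer steps,
# only once per 4 micro-batches run under accelerator.accumulate(model).
accelerator = Accelerator(gradient_accumulation_steps=4)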
@@ -105,6 +107,8 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
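A quick sanity check on the arithmetic above, with hypothetical sizes (a real epoch length depends on the dataset and batch size). The ceil matters because a trailing partial accumulation window still produces one optimizer update, and this count is exactly what the progress bar below is sized to:

import math

micro_batches_per_epoch = 1000  # stands in for len(train_dataloader)
gradient_accumulation_steps = 4

num_update_steps_per_epoch = math.ceil(micro_batches_per_epoch / gradient_accumulation_steps)
assert num_update_steps_per_epoch == 250   # 1000 / 4 exactly
assert math.ceil(1001 / 4) == 251          # a leftover partial window adds one update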
@@ -117,7 +121,7 @@ def main(args):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
@@ -146,13 +150,16 @@ def main(args):
                     ema_model.step(model)
                 optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
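For context, a minimal, self-contained sketch of the pattern this commit lands on (not the script itself: the toy model, data, and step counts are placeholders, and a recent accelerate is assumed). Inside accelerator.accumulate, the prepared optimizer's step() and zero_grad() are skipped internally on non-boundary micro-batches, and accelerator.sync_gradients is True exactly when a real optimizer update ran, which is why the progress bar and global_step are advanced behind that check:

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

global_step = 0
for step in range(100):  # stands in for iterating a real dataloader
    x = torch.randn(8, 10, device=accelerator.device)
    y = torch.randn(8, 1, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()       # skipped internally while still accumulating
        optimizer.zero_grad()  # likewise deferred to the real update
    if accelerator.sync_gradients:  # True on every 4th micro-batch here
        global_step += 1

assert global_step == 25  # 100 micro-batches / 4 accumulation steps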