Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[BUG] Fix AutoencoderKL train script (#11113)

* add disc_optimizer step (not fix)
* support syncbatchnorm in discriminator
@@ -627,6 +627,7 @@ def main(args):
     ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
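The SyncBatchNorm change matters when the script runs on several GPUs: per-device BatchNorm statistics computed on small discriminator batches can drift between processes, so the patch converts the discriminator's norm layers before training starts. A minimal sketch of the same conversion, using a toy Sequential discriminator as a stand-in for the script's NLayerDiscriminator (an assumed stand-in, not the real module):

```python
import torch
import torch.nn as nn

# Toy PatchGAN-style discriminator standing in for NLayerDiscriminator.
toy_disc = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1),
)

# convert_sync_batchnorm walks the module tree and replaces every BatchNorm
# layer with torch.nn.SyncBatchNorm; under DDP the batch statistics are then
# reduced across all processes instead of being kept per GPU.
toy_disc = torch.nn.SyncBatchNorm.convert_sync_batchnorm(toy_disc)

print(toy_disc)  # the BatchNorm2d layer now shows up as SyncBatchNorm
```

In recent PyTorch releases the converted layers fall back to ordinary batch norm when no process group is initialized, so the snippet above also runs on a single device.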
@@ -951,13 +952,20 @@ def main(args):
     logits_fake = discriminator(reconstructions)
     disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-    disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+    d_loss = disc_factor * disc_loss(logits_real, logits_fake)
     logs = {
-        "disc_loss": disc_loss.detach().mean().item(),
+        "disc_loss": d_loss.detach().mean().item(),
         "logits_real": logits_real.detach().mean().item(),
         "logits_fake": logits_fake.detach().mean().item(),
         "disc_lr": disc_lr_scheduler.get_last_lr()[0],
     }
+    accelerator.backward(d_loss)
+    if accelerator.sync_gradients:
+        params_to_clip = discriminator.parameters()
+        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+    disc_optimizer.step()
+    disc_lr_scheduler.step()
+    disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
     # Checks if the accelerator has performed an optimization step behind the scenes
     if accelerator.sync_gradients:
         progress_bar.update(1)
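The substance of the fix is the second half of this hunk: the discriminator branch now runs its own backward/clip/step/zero_grad cycle on disc_optimizer, and the loss tensor is renamed to d_loss, presumably so it no longer reuses the name of the hinge_d_loss/vanilla_d_loss callable selected a few lines above. A self-contained sketch of the same update pattern, with a toy linear discriminator and a plain PyTorch optimizer standing in for the script's NLayerDiscriminator and accelerate-prepared disc_optimizer (assumed stand-ins, using plain .backward() and clip_grad_norm_ in place of the accelerator calls):

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Standard hinge discriminator loss: push real logits above +1
    # and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

# Toy stand-ins for the discriminator and its optimizer.
disc = torch.nn.Linear(8, 1)
disc_optimizer = torch.optim.Adam(disc.parameters(), lr=1e-4)

# In the script, `targets` are the input images and `reconstructions`
# come from the VAE; here they are just random tensors.
targets = torch.randn(4, 8)
reconstructions = torch.randn(4, 8)

logits_real = disc(targets)
logits_fake = disc(reconstructions)

disc_factor = 1.0  # the script gates this on global_step >= args.disc_start
d_loss = disc_factor * hinge_d_loss(logits_real, logits_fake)

# The step the patch adds: actually update the discriminator's parameters.
d_loss.backward()
torch.nn.utils.clip_grad_norm_(disc.parameters(), max_norm=1.0)
disc_optimizer.step()
disc_optimizer.zero_grad(set_to_none=True)
```

Without these calls the discriminator's parameters were never updated, which is the missing disc_optimizer step named in the commit message.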