mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
train_unconditional save restore unet parameters (#2706)
This commit is contained in:
@@ -625,8 +625,11 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
unet = accelerator.unwrap_model(model)
|
||||
|
||||
if args.use_ema:
|
||||
ema_model.store(unet.parameters())
|
||||
ema_model.copy_to(unet.parameters())
|
||||
|
||||
pipeline = DDPMPipeline(
|
||||
unet=unet,
|
||||
scheduler=noise_scheduler,
|
||||
@@ -641,6 +644,9 @@ def main(args):
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
if args.use_ema:
|
||||
ema_model.restore(unet.parameters())
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
|
||||
@@ -659,7 +665,22 @@ def main(args):
|
||||
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
unet = accelerator.unwrap_model(model)
|
||||
|
||||
if args.use_ema:
|
||||
ema_model.store(unet.parameters())
|
||||
ema_model.copy_to(unet.parameters())
|
||||
|
||||
pipeline = DDPMPipeline(
|
||||
unet=unet,
|
||||
scheduler=noise_scheduler,
|
||||
)
|
||||
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
if args.use_ema:
|
||||
ema_model.restore(unet.parameters())
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user