From cc2c4ae759c915e76f020a5ddd1764b8063dc79d Mon Sep 17 00:00:00 2001 From: Pu Cao <48318302+caopulan@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:08:01 +0800 Subject: [PATCH] fix inference in custom diffusion (#5329) * Update train_custom_diffusion.py * make style * Empty-Commit --------- Co-authored-by: Sayak Paul --- .../train_custom_diffusion.py | 82 ++++++++++--------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 8d90998700..4773446a61 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -1214,50 +1214,52 @@ def main(args): if global_step >= args.max_train_steps: break - if accelerator.is_main_process: - images = [] + if accelerator.is_main_process: + images = [] - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # create pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - tokenizer=tokenizer, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0] - for _ in range(args.num_validation_images) - ] + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[ + 0 + ] + for _ in range(args.num_validation_images) + ] - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) - del pipeline - torch.cuda.empty_cache() + del pipeline + torch.cuda.empty_cache() # Save the custom diffusion layers accelerator.wait_for_everyone()