From c1971a53bc8bb95e36e7135351065ed3814c8532 Mon Sep 17 00:00:00 2001 From: Isamu Isozaki Date: Wed, 8 Feb 2023 18:37:10 +0900 Subject: [PATCH] Textual inv save log memory (#2184) * Quality check and adding tokenizer * Adapted stable diffusion to mixed precision + finished up style fixes * Fixed based on Patrick's review * Fixed OOM from number of validation images * Removed unnecessary np.array conversion --------- Co-authored-by: Patrick von Platen --- examples/textual_inversion/textual_inversion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 254d9e6a0f..c61c2ae44c 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -781,7 +781,10 @@ def main(): args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, + unet=unet, + vae=vae, revision=args.revision, + torch_dtype=weight_dtype, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) @@ -791,8 +794,11 @@ def main(): generator = ( None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) ) - prompt = args.num_validation_images * [args.validation_prompt] - images = pipeline(prompt, num_inference_steps=25, generator=generator).images + images = [] + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard":