diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 254d9e6a0f..c61c2ae44c 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -781,7 +781,10 @@ def main(): args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, + unet=unet, + vae=vae, revision=args.revision, + torch_dtype=weight_dtype, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) @@ -791,8 +794,11 @@ def main(): generator = ( None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) ) - prompt = args.num_validation_images * [args.validation_prompt] - images = pipeline(prompt, num_inference_steps=25, generator=generator).images + images = [] + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard":