diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index d6d0dee088..2efbaf298d 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -847,10 +847,11 @@ def main(): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -916,8 +917,11 @@ def main(): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if len(images) != 0: