From 2d43094ffc9b1ee377651c6c8a358c81f0c96005 Mon Sep 17 00:00:00 2001 From: mwkldeveloper Date: Sun, 24 Dec 2023 17:04:35 +0800 Subject: [PATCH] fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same in train_text_to_image_lora.py (#6259) * fix RuntimeError: Input type (float) and bias type (c10::Half) should be the same * format source code * format code * remove the autocast blocks within the pipeline * add autocast blocks to pipeline caller in train_text_to_image_lora.py --- .../text_to_image/train_text_to_image_lora.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index d6d0dee088..2efbaf298d 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -847,10 +847,11 @@ def main(): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -916,8 +917,11 @@ def main(): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if len(images) != 0: