From 1acc6786e5ead6e0b0e81e1cf5eae79173de97cf Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 13:49:21 +0200 Subject: [PATCH] fix autocast --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9064410a45..eccf9f1425 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -363,7 +363,7 @@ def main(): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext + context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext with context: images = pipeline(example["prompt"]).images all_images.extend(images)