diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9064410a45..eccf9f1425 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -363,7 +363,7 @@ def main(): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext + context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext with context: images = pipeline(example["prompt"]).images all_images.extend(images)