1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix autocast

This commit is contained in:
patil-suraj
2022-09-26 13:49:21 +02:00
parent 195cd463a8
commit 1acc6786e5

View File

@@ -363,7 +363,7 @@ def main():
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
with context:
images = pipeline(example["prompt"]).images
all_images.extend(images)