mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix autocast
This commit is contained in:
@@ -363,7 +363,7 @@ def main():
|
||||
for example in tqdm(
|
||||
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
||||
):
|
||||
context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext
|
||||
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
|
||||
with context:
|
||||
images = pipeline(example["prompt"]).images
|
||||
all_images.extend(images)
|
||||
|
||||
Reference in New Issue
Block a user