1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Flax SD finetune] Fix dtype (#1038)

fix jnp dtype
This commit is contained in:
Duong A. Nguyen
2022-10-28 16:21:34 +07:00
committed by GitHub
parent fb38bb1621
commit 1e07b6b334

View File

@@ -371,11 +371,11 @@ def main():
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
)
weight_dtype = torch.float32
weight_dtype = jnp.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
weight_dtype = jnp.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
weight_dtype = jnp.bfloat16
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")