From 1e07b6b334c0b94e06ca9860641bd2e16aff90fb Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:21:34 +0700 Subject: [PATCH] [Flax SD finetune] Fix dtype (#1038) fix jnp dtype --- examples/text_to_image/train_text_to_image_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index ab8d2b7ee2..cacfacef49 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -371,11 +371,11 @@ def main(): train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True ) - weight_dtype = torch.float32 + weight_dtype = jnp.float32 if args.mixed_precision == "fp16": - weight_dtype = torch.float16 + weight_dtype = jnp.float16 elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 + weight_dtype = jnp.bfloat16 # Load models and create wrapper for stable diffusion tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")