diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index e72df2e150..b7ac105c24 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = { "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"), } WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] +TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} def import_model_class_from_model_name_or_path( @@ -100,6 +101,16 @@ def parse_args(): default=None, help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", ) + parser.add_argument( + "--vae_precision", + type="choice", + choices=["fp32", "fp16", "bf16"], + default="fp32", + help=( + "The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution" + " to this problem, and this flag allows you to use mixed precision to stabilize training." + ), + ) parser.add_argument( "--revision", type=str, @@ -878,7 +889,7 @@ def main(): if args.pretrained_vae_model_name_or_path is not None: vae.to(accelerator.device, dtype=weight_dtype) else: - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision]) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)