From cfdfcf20181974243cdd47d8f9c781d98aede058 Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 5 Sep 2023 00:04:06 -0700 Subject: [PATCH] =?UTF-8?q?Add=20--vae=5Fprecision=20option=20to=20the=20S?= =?UTF-8?q?DXL=20pix2pix=20script=20so=20that=20we=20have=E2=80=A6=20(#488?= =?UTF-8?q?1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add --vae_precision option to the SDXL pix2pix script so that we have the option of avoiding float32 overhead * style --------- Co-authored-by: bghira --- .../instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index e72df2e150..b7ac105c24 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = { "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"), } WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] +TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} def import_model_class_from_model_name_or_path( @@ -100,6 +101,16 @@ def parse_args(): default=None, help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", ) + parser.add_argument( + "--vae_precision", + type=str, + choices=["fp32", "fp16", "bf16"], + default="fp32", + help=( + "The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution" + " to this problem, and this flag allows you to use mixed precision to stabilize training." 
+ ), + ) parser.add_argument( "--revision", type=str, @@ -878,7 +889,7 @@ def main(): if args.pretrained_vae_model_name_or_path is not None: vae.to(accelerator.device, dtype=weight_dtype) else: - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision]) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)