
Add --vae_precision option to the SDXL pix2pix script so that we have the option of avoiding float32 overhead (#4881)

* Add --vae_precision option to the SDXL pix2pix script so that we have the option of avoiding float32 overhead

* style

---------

Co-authored-by: bghira <bghira@users.github.com>
Author: Bagheera
Committed by: GitHub
Date: 2023-09-05 00:04:06 -07:00
Parent: e4b8e7928b
Commit: cfdfcf2018
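The mechanism is small: a string flag constrained to fp32/fp16/bf16 is looked up in a module-level dtype mapping and used when moving the VAE to the accelerator device. A minimal standalone sketch of that pattern follows (illustrative only; apart from TORCH_DTYPE_MAPPING and the flag name, nothing here is taken verbatim from the script):

import argparse

import torch

# Same mapping the commit adds at module level in the training script.
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

parser = argparse.ArgumentParser()
parser.add_argument("--vae_precision", type=str, choices=["fp32", "fp16", "bf16"], default="fp32")

# For example, passing --vae_precision bf16 selects torch.bfloat16 for the VAE.
args = parser.parse_args(["--vae_precision", "bf16"])
assert TORCH_DTYPE_MAPPING[args.vae_precision] is torch.bfloat16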


@@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = {
     "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
 }
 WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
+TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
 
 def import_model_class_from_model_name_or_path(
@@ -100,6 +101,16 @@ def parse_args():
         default=None,
         help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
     )
+    parser.add_argument(
+        "--vae_precision",
+        type=str,
+        choices=["fp32", "fp16", "bf16"],
+        default="fp32",
+        help=(
+            "The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
+            " to this problem, and this flag allows you to use mixed precision to stabilize training."
+        ),
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -878,7 +889,7 @@ def main():
     if args.pretrained_vae_model_name_or_path is not None:
         vae.to(accelerator.device, dtype=weight_dtype)
     else:
-        vae.to(accelerator.device, dtype=torch.float32)
+        vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
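The net effect of the last hunk: a VAE supplied via --pretrained_vae_model_name_or_path is still cast to the mixed-precision weight dtype, while the stock SDXL VAE, previously forced to float32, now follows --vae_precision. A minimal sketch of that dispatch, assuming a hypothetical helper cast_vae (the function name and signature are illustrative, not part of the script):

import torch

TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

def cast_vae(vae, device, weight_dtype, pretrained_vae_path=None, vae_precision="fp32"):
    # A user-supplied VAE is assumed to be numerically stable in the training weight dtype.
    if pretrained_vae_path is not None:
        return vae.to(device, dtype=weight_dtype)
    # The stock SDXL 1.0 VAE keeps the old fp32 default, but can now be lowered to
    # fp16/bf16 via --vae_precision to avoid the float32 overhead.
    return vae.to(device, dtype=TORCH_DTYPE_MAPPING[vae_precision])

Because the default stays fp32, existing training runs behave exactly as before; lower precisions are opt-in for VAEs known to be stable at fp16 or bf16.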