mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add --vae_precision option to the SDXL pix2pix script so that we have… (#4881)
* Add --vae_precision option to the SDXL pix2pix script so that we have the option of avoiding float32 overhead * style --------- Co-authored-by: bghira <bghira@users.github.com>
This commit is contained in:
@@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = {
|
||||
"fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
|
||||
}
|
||||
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
||||
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
@@ -100,6 +101,16 @@ def parse_args():
|
||||
default=None,
|
||||
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_precision",
|
||||
type="choice",
|
||||
choices=["fp32", "fp16", "bf16"],
|
||||
default="fp32",
|
||||
help=(
|
||||
"The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
|
||||
" to this problem, and this flag allows you to use mixed precision to stabilize training."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
@@ -878,7 +889,7 @@ def main():
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
else:
|
||||
vae.to(accelerator.device, dtype=torch.float32)
|
||||
vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
Reference in New Issue
Block a user