diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 377f226150..ddc3a60876 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -247,6 +247,16 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--prior_generation_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp32", "fp16", "bf16"],
+        help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+        ),
+    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -436,6 +446,12 @@ def main(args):
 
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            if args.prior_generation_precision == "fp32":
+                torch_dtype = torch.float32
+            elif args.prior_generation_precision == "fp16":
+                torch_dtype = torch.float16
+            elif args.prior_generation_precision == "bf16":
+                torch_dtype = torch.bfloat16
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
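
For reference, a minimal sketch (not part of the patch) of how the added dtype-selection logic behaves during class-image generation for prior preservation. The helper name resolve_prior_generation_dtype and the checkpoint id are illustrative assumptions; the flag values and the fp16/fp32 fallback mirror the hunks above.

    import torch
    from diffusers import DiffusionPipeline

    def resolve_prior_generation_dtype(prior_generation_precision, device_type):
        # Default used by the script: fp16 on CUDA, fp32 otherwise.
        torch_dtype = torch.float16 if device_type == "cuda" else torch.float32
        # --prior_generation_precision overrides the default when set.
        if prior_generation_precision == "fp32":
            torch_dtype = torch.float32
        elif prior_generation_precision == "fp16":
            torch_dtype = torch.float16
        elif prior_generation_precision == "bf16":
            torch_dtype = torch.bfloat16
        return torch_dtype

    # Hypothetical usage: load the pipeline for class-image generation in the chosen precision.
    torch_dtype = resolve_prior_generation_dtype("bf16", "cuda")
    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # any Stable Diffusion checkpoint; placeholder here
        torch_dtype=torch_dtype,
    )

On the command line this corresponds to passing, e.g., --prior_generation_precision bf16 to the training script alongside the existing prior-preservation flags.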