From 16ecc089e3ce400a27ec5ce91e72da0c706dd8a6 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 26 Sep 2022 15:10:54 +0200
Subject: [PATCH] add grad ckpt

---
 examples/dreambooth/train_dreambooth.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index d3f8413415..c44994bbf5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -120,6 +120,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -388,10 +393,14 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
     )
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
+
     optimizer = torch.optim.AdamW(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
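For context on what the new flag toggles: `unet.enable_gradient_checkpointing()` makes the model recompute intermediate activations during the backward pass instead of caching them in the forward pass, lowering peak memory at the cost of a slower backward pass (as the help text says). The sketch below is only an illustration of that general idea using `torch.utils.checkpoint`; the `TinyBlock`/`TinyModel` names are invented for the example and this is not the diffusers implementation.

```python
# Illustrative sketch of gradient checkpointing (not part of the patch above).
import torch
from torch.utils.checkpoint import checkpoint


class TinyBlock(torch.nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())

    def forward(self, x):
        return self.net(x)


class TinyModel(torch.nn.Module):
    """Toy stand-in for the UNet, with a flag mirroring --gradient_checkpointing."""

    def __init__(self, gradient_checkpointing: bool = False):
        super().__init__()
        self.blocks = torch.nn.ModuleList(TinyBlock() for _ in range(4))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Activations inside `block` are recomputed during backward
                # instead of being stored, saving memory for extra compute.
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    model = TinyModel(gradient_checkpointing=True).train()
    out = model(torch.randn(8, 64, requires_grad=True))
    out.sum().backward()  # slower backward, lower peak activation memory
```

With the patch applied, the feature is opt-in: since the argument uses `action="store_true"`, passing a bare `--gradient_checkpointing` on the training script's command line enables it, and omitting it leaves the default behavior unchanged.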