diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index d3f8413415..c44994bbf5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -120,6 +120,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -388,10 +393,14 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
     )
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
+
     optimizer = torch.optim.AdamW(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
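
For context: diffusers' enable_gradient_checkpointing() ultimately routes the UNet blocks through torch.utils.checkpoint, which discards intermediate activations on the forward pass and recomputes them during backward. Below is a minimal standalone sketch of that trade-off; ToyBlock and its dimensions are illustrative and not part of the training script.

# A minimal sketch of the memory/compute trade-off made by gradient
# checkpointing, using torch.utils.checkpoint directly.
# ToyBlock and its sizes are illustrative, not taken from the script.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x, use_checkpoint=False):
        if use_checkpoint:
            # Activations inside self.net are not stored; backward
            # recomputes them, saving memory at the cost of an extra
            # forward pass through the block.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


x = torch.randn(4, 512, requires_grad=True)
out = ToyBlock()(x, use_checkpoint=True)
out.sum().backward()  # gradients match the non-checkpointed path

With the patch applied the behavior is opt-in: since the new argument uses action="store_true", passing --gradient_checkpointing on the command line enables recomputation, and omitting it keeps the usual full-activation forward pass.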