1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Add gradient checkpointing

This commit is contained in:
patil-suraj
2022-09-26 15:10:54 +02:00
parent 2894a92f92
commit 16ecc089e3

View File

@@ -120,6 +120,11 @@ def parse_args():
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -388,10 +393,14 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
optimizer = torch.optim.AdamW(
unet.parameters(), # only optimize unet
lr=args.learning_rate,