mirror of https://github.com/huggingface/diffusers.git
add grad ckpt
@@ -120,6 +120,11 @@ def parse_args():
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
@@ -388,10 +393,14 @@ def main():
        args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
    )

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    optimizer = torch.optim.AdamW(
        unet.parameters(),  # only optimize unet
        lr=args.learning_rate,
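
For context, a minimal sketch of how the added flag flows through a diffusers-style training script is shown below. The standalone argparse setup, the default hyperparameter values, and the single-process learning-rate scaling are assumptions for illustration, not taken verbatim from the script touched by this commit.

# Minimal sketch, assuming a diffusers-style training script; defaults and
# the single-process LR scaling below are placeholders, not the real values.
import argparse

import torch
from diffusers import UNet2DConditionModel

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--train_batch_size", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--scale_lr", action="store_true")
parser.add_argument(
    "--gradient_checkpointing",
    action="store_true",
    help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
args = parser.parse_args()

unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

# Trade compute for memory: activations are recomputed during the backward pass.
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()

if args.scale_lr:
    # The real script also multiplies by accelerator.num_processes; assume 1 here.
    args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size

optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)  # only optimize unet

As the new help text notes, gradient checkpointing lowers peak activation memory in exchange for a slower backward pass, which is why it is opt-in via a flag rather than enabled by default.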