diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 6d28536243..f6575f7494 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -486,6 +486,9 @@ def main(): lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 3dab86a484..81dbd0f1f6 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -706,6 +706,12 @@ def main(args): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: