
Fix gradient-checkpointing option being ignored in SDXL+LoRA training. (#6388) (#6402)

* Fix gradient-checkpointing option being ignored in SDXL+LoRA training. (#6388)

* Fix gradient-checkpointing option being ignored in SD+LoRA training.

* Fix gradient checkpointing not being applied to the text encoders. (SDXL+LoRA)

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Authored by 2510 on 2024-01-01 12:21:04 +09:00; committed by GitHub
parent 61d223c884
commit 8a366b835c
2 changed files with 9 additions and 0 deletions


@@ -486,6 +486,9 @@ def main():
     lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
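
For context, a minimal sketch (not part of this commit) of what the re-added block turns on: gradient checkpointing discards intermediate activations during the forward pass and recomputes them in the backward pass, trading extra compute for lower GPU memory. The model ID below is illustrative.

# Minimal sketch, assuming diffusers is installed; model ID is illustrative.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.train()  # checkpointing only has an effect while gradients are tracked
unet.enable_gradient_checkpointing()  # the call the hunk above re-adds
print(unet.is_gradient_checkpointing)  # True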


@@ -706,6 +706,12 @@ def main(args):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder_one.gradient_checkpointing_enable()
+            text_encoder_two.gradient_checkpointing_enable()
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
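
One subtlety the hunk above encodes, sketched below: the UNet is a diffusers model and exposes enable_gradient_checkpointing(), while the two text encoders are transformers models and expose gradient_checkpointing_enable(); the near-mirrored method names are easy to mix up. A minimal sketch, with an illustrative model ID:

# Minimal sketch of the two APIs (assumes transformers is installed;
# model ID is illustrative). Note the reversed method names:
#   diffusers models:    model.enable_gradient_checkpointing()
#   transformers models: model.gradient_checkpointing_enable()
from transformers import CLIPTextModel

text_encoder_one = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
)
text_encoder_one.gradient_checkpointing_enable()  # transformers-style call
print(text_encoder_one.is_gradient_checkpointing)  # True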