Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix gradient-checkpointing option is ignored in SDXL+LoRA training. (#6388)

* Fix gradient-checkpointing option is ignored in SD+LoRA training.
* Fix gradient checkpoint is not applied to text encoders. (SDXL+LoRA)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -486,6 +486,9 @@ def main():
 
     lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
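The first hunk is from the SDXL text-to-image LoRA script: --gradient_checkpointing was parsed but never acted on, so training still ran, just with a larger activation-memory footprint, which made the bug easy to miss. A minimal sketch of what the added call does, assuming a standard diffusers UNet2DConditionModel (the checkpoint path is illustrative):

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any SDXL-style UNet behaves the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)  # base weights stay frozen in LoRA training

# This is the call the fix wires up to --gradient_checkpointing:
# intermediate activations are recomputed during the backward pass
# instead of being stored, trading compute for memory.
unet.enable_gradient_checkpointing()
unet.train()  # checkpointing only takes effect in training mode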
@@ -706,6 +706,12 @@ def main(args):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder_one.gradient_checkpointing_enable()
+            text_encoder_two.gradient_checkpointing_enable()
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
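The second hunk is from the SDXL DreamBooth LoRA script, which can additionally train the two CLIP text encoders. Note the naming asymmetry the diff exposes: diffusers models use enable_gradient_checkpointing(), while transformers models use gradient_checkpointing_enable(). A minimal sketch of the text-encoder half of the fix, assuming SDXL's two CLIP encoders (the checkpoint path is illustrative):

from transformers import CLIPTextModel, CLIPTextModelWithProjection

# SDXL ships its two text encoders in separate subfolders.
text_encoder_one = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
)

# transformers' spelling of the same feature (cf. diffusers'
# enable_gradient_checkpointing() used on the UNet above).
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()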