From 8e1b7a084addc4711b8d9be2738441dfad680ce0 Mon Sep 17 00:00:00 2001 From: spacepxl <143970342+spacepxl@users.noreply.github.com> Date: Sun, 16 Jun 2024 15:52:33 -0400 Subject: [PATCH] Fix the deletion of SD3 text encoders for Dreambooth/LoRA training if the text encoders are not being trained (#8536) * Update train_dreambooth_sd3.py to fix TE garbage collection * Update train_dreambooth_lora_sd3.py to fix TE garbage collection --------- Co-authored-by: Kashif Rasul Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_lora_sd3.py | 3 +++ examples/dreambooth/train_dreambooth_sd3.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 67227e2def..8f831b50d0 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1316,6 +1316,9 @@ def main(args): # Clear the memory here if not train_dataset.custom_instance_prompts: del tokenizers, text_encoders + # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection + del tokenizer_one, tokenizer_two, tokenizer_three + del text_encoder_one, text_encoder_two, text_encoder_three gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 7920b4c8e0..9f89f3a48d 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1347,6 +1347,9 @@ def main(args): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders + # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection + del tokenizer_one, tokenizer_two, tokenizer_three + del text_encoder_one, text_encoder_two, text_encoder_three gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()