From 677df5ac12fe6b159a9c041677eed71ae2d8d04a Mon Sep 17 00:00:00 2001
From: Shyam Marjit <54628184+shyammarjit@users.noreply.github.com>
Date: Mon, 23 Oct 2023 23:13:43 +0530
Subject: [PATCH] fixed SDXL text encoder training bug #5016 (#5078)

Co-authored-by: Sayak Paul
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index caf04f4308..9a7b54a69c 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1070,6 +1070,11 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             text_encoder_two.train()
+
+            # set top parameter requires_grad = True so that gradient checkpointing works
+            text_encoder_one.text_model.embeddings.requires_grad_(True)
+            text_encoder_two.text_model.embeddings.requires_grad_(True)
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
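
Editor's note (not part of the commit above): PyTorch's reentrant activation checkpointing only backpropagates through a checkpointed segment if at least one tensor fed into that segment requires grad. With LoRA training the text encoders' base weights are frozen, so before this change the embedding output entered the checkpointed transformer blocks with requires_grad=False and the adapter gradients were silently dropped. Below is a minimal, self-contained sketch of that failure mode and of the workaround the patch applies; the Embedding and Linear modules are stand-ins for illustration, not the actual diffusers/transformers models.

    import torch
    from torch.utils.checkpoint import checkpoint

    embeddings = torch.nn.Embedding(10, 4)   # stand-in for text_model.embeddings
    block = torch.nn.Linear(4, 4)            # stand-in for a checkpointed transformer block
    token_ids = torch.tensor([[1, 2, 3]])

    # Frozen embeddings, as in LoRA fine-tuning: the hidden states entering the
    # checkpointed segment do not require grad, so with the reentrant checkpoint
    # implementation backward never reaches the block's parameters.
    embeddings.requires_grad_(False)
    hidden = embeddings(token_ids)
    out = checkpoint(block, hidden, use_reentrant=True)
    print(out.requires_grad)                 # False: no gradients will flow to `block`

    # The commit's workaround: mark the top-level embeddings as requiring grad so
    # every checkpointed segment sees at least one grad-requiring input.
    embeddings.requires_grad_(True)
    hidden = embeddings(token_ids)
    out = checkpoint(block, hidden, use_reentrant=True)
    out.sum().backward()
    print(block.weight.grad is not None)     # True: gradients now flow

The gradient-checkpointing path used by the text encoders relied on the reentrant checkpoint implementation (the PyTorch default at the time), which is why explicitly setting requires_grad_(True) on the embeddings is sufficient to restore LoRA gradients here.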