From dff5ff35a9be8f2809134d08e7e41711cb9f34ed Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 7 Aug 2023 13:27:13 +0200
Subject: [PATCH] [SDXL LoRA] fix batch size lora (#4509)

fix batch size lora
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index a2b6e4a382..6f99dbc64d 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1103,11 +1103,11 @@ def main(args):
                         "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
                     }
-                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
                         noisy_model_input,
                         timesteps,
-                        prompt_embeds,
+                        prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                     ).sample
                 else:
@@ -1119,9 +1119,9 @@ def main(args):
                         text_input_ids_list=[tokens_one, tokens_two],
                     )
                     unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
-                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
                     ).sample

                 # Get the target for loss depending on the prediction type
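
Why the rename matters: when the text encoders are not being trained, prompt_embeds is precomputed once before the training loop, so writing the .repeat(...) result back into the same name makes the tensor grow on every step and the batch dimension no longer matches noisy_model_input. Storing the repeated tensor in a separate prompt_embeds_input leaves the cached embeddings untouched. Below is a minimal, self-contained sketch of that effect; the shapes and the loop are illustrative only and are not taken from the training script, while prompt_embeds, prompt_embeds_input, and elems_to_repeat mirror the names used in the patch.

import torch

# Illustrative shapes: (batch_size, sequence_length, embedding_dim).
prompt_embeds = torch.randn(2, 77, 2048)
elems_to_repeat = 2  # e.g. the batch is duplicated for prior preservation

# Buggy pattern: the repeated tensor overwrites the cached embeddings,
# so the batch dimension compounds across steps (4, 8, 16, ...).
buggy = prompt_embeds.clone()
for step in range(3):
    buggy = buggy.repeat(elems_to_repeat, 1, 1)
    print("buggy step", step, "batch dim:", buggy.shape[0])

# Fixed pattern from the patch: repeat into a separate input tensor,
# so the cached embeddings keep their original batch dimension (always 4 here).
for step in range(3):
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
    print("fixed step", step, "batch dim:", prompt_embeds_input.shape[0])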