1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[SDXL LoRA] fix batch size lora (#4509)

fix batch size lora
This commit is contained in:
Patrick von Platen
2023-08-07 13:27:13 +02:00
committed by GitHub
parent b2456717e6
commit dff5ff35a9

View File

@@ -1103,11 +1103,11 @@ def main(args):
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
}
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
prompt_embeds,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
else:
@@ -1119,9 +1119,9 @@ def main(args):
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
# Get the target for loss depending on the prediction type