mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
b2456717e6
commit
dff5ff35a9
@@ -1103,11 +1103,11 @@ def main(args):
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
|
||||
}
|
||||
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
else:
|
||||
@@ -1119,9 +1119,9 @@ def main(args):
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
|
||||
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
|
||||
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
|
||||
Reference in New Issue
Block a user