1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[AuraFlow] fix long prompt handling (#8937)

fix
This commit is contained in:
Sayak Paul
2024-07-24 11:19:30 +05:30
committed by GitHub
parent 93983b6780
commit 2c25b98c8e

View File

@@ -260,7 +260,6 @@ class AuraFlowPipeline(DiffusionPipeline):
padding="max_length",
return_tensors="pt",
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
text_input_ids = text_inputs["input_ids"]
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
@@ -273,6 +272,7 @@ class AuraFlowPipeline(DiffusionPipeline):
f" {max_length} tokens: {removed_text}"
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
prompt_embeds = self.text_encoder(**text_inputs)[0]
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * prompt_attention_mask