From 2c25b98c8ea74cfb5ec56ba49cc6edafef0b26af Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 24 Jul 2024 11:19:30 +0530
Subject: [PATCH] [AuraFlow] fix long prompt handling (#8937)

fix
---
 src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 8b5a85b5ab..6a86b5cede 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -260,7 +260,6 @@ class AuraFlowPipeline(DiffusionPipeline):
                 padding="max_length",
                 return_tensors="pt",
             )
-            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
             text_input_ids = text_inputs["input_ids"]
             untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
@@ -273,6 +272,7 @@ class AuraFlowPipeline(DiffusionPipeline):
                 f" {max_length} tokens: {removed_text}"
             )
 
+            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
             prompt_embeds = self.text_encoder(**text_inputs)[0]
             prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
             prompt_embeds = prompt_embeds * prompt_attention_mask
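
Why the reordering fixes long prompts: both `text_input_ids` and `untruncated_ids` come out of the tokenizer on the CPU, but the old code moved `text_inputs` onto the pipeline's device before the truncation check, so that check compared a device tensor against a CPU tensor. Prompts short enough to skip the comparison were unaffected; prompts longer than `max_length` hit it and raised a device-mismatch error. The patch delays the transfer until just before the text encoder call, the only consumer that needs the tensors on the device.

Below is a minimal, self-contained sketch of the failure mode, using hypothetical stand-in tensors rather than the pipeline's actual code, and assuming the surrounding code (not shown in the hunks) compares the two id tensors with `torch.equal`, as other diffusers pipelines do:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_length = 4

    # Stand-ins for the two tokenizer outputs: ids truncated/padded to
    # max_length, and the untruncated ids of a longer prompt.
    text_input_ids = torch.tensor([[1, 2, 3, 4]])
    untruncated_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])

    # Old ordering: move to the device *before* the truncation check.
    ids_on_device = text_input_ids.to(device)
    if untruncated_ids.shape[-1] >= ids_on_device.shape[-1]:
        try:
            # On CUDA this compares a device tensor with a CPU tensor.
            torch.equal(ids_on_device, untruncated_ids)
        except RuntimeError as err:
            print(f"old ordering fails on long prompts: {err}")

    # New ordering (what the patch does): run the check while both
    # tensors are still on the CPU, then transfer afterwards.
    truncated = not torch.equal(text_input_ids, untruncated_ids)
    text_input_ids = text_input_ids.to(device)  # safe: check already done
    print(f"prompt was truncated: {truncated}")

On a CPU-only machine both orderings behave the same (`torch.equal` simply returns False for the shape mismatch), which is why the bug only surfaced when the pipeline ran on an accelerator.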