diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 6f2bed9ce2..147e2b76e6 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -156,6 +156,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): mask_feature: (bool, defaults to `True`): If `True`, the function will mask the text embeddings. """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + if device is None: device = self._execution_device @@ -253,7 +255,7 @@ class PixArtAlphaPipeline(DiffusionPipeline): negative_prompt_embeds = None # Perform additional masking. - if mask_feature and prompt_embeds is None and negative_prompt_embeds is None: + if mask_feature and not embeds_initially_provided: prompt_embeds = prompt_embeds.unsqueeze(1) masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) masked_prompt_embeds = masked_prompt_embeds.squeeze(1) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 70548092fe..a04f4e1a88 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -120,7 +120,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, - "mask_feature": False, } # set all optional components to None @@ -155,7 +154,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, - "mask_feature": False, } output_loaded = pipe_loaded(**inputs)[0]