1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[PixArt-Alpha] fix mask feature condition. (#5695)

* fix mask feature condition.

* debug

* remove identical test

* set correct

* Empty-Commit
This commit is contained in:
Sayak Paul
2023-11-08 17:42:46 +05:30
committed by GitHub
parent c803a8f8c0
commit 78be400761
2 changed files with 3 additions and 3 deletions

View File

@@ -156,6 +156,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if device is None:
device = self._execution_device
@@ -253,7 +255,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
negative_prompt_embeds = None
# Perform additional masking.
if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)

View File

@@ -120,7 +120,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
# set all optional components to None
@@ -155,7 +154,6 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
output_loaded = pipe_loaded(**inputs)[0]