diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 1f39cc168c..6f2bed9ce2 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -253,7 +253,7 @@ class PixArtAlphaPipeline(DiffusionPipeline): negative_prompt_embeds = None # Perform additional masking. - if mask_feature: + if mask_feature and prompt_embeds is None and negative_prompt_embeds is None: prompt_embeds = prompt_embeds.unsqueeze(1) masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) masked_prompt_embeds = masked_prompt_embeds.squeeze(1) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 1797f7e0fe..2e033896d9 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -181,11 +181,76 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_inference_with_embeddings_and_multiple_images(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "negative_prompt": None, + "negative_prompt_embeds": negative_prompt_embeds, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "num_images_per_prompt": 2, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "negative_prompt": None, + "negative_prompt_embeds": negative_prompt_embeds, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "num_images_per_prompt": 2, + } + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) -# TODO: needs to be updated. @slow @require_torch_gpu class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):