From a5720e9e3124753c85b2260dec5f39d75ce18245 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Nov 2023 17:15:35 +0530 Subject: [PATCH] [PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (#5752) * does this fix things? * attention mask use * attention mask order * better masking. * add: test * remove mask_feature * test * debug * fix: tests * deprecate mask_feature * add deprecation test * add slow test * add print statements to retrieve the assertion values. * fix for the 1024 fast test * fix test * fix the remaining * Apply suggestions from code review * more debug --------- Co-authored-by: Patrick von Platen --- .../pixart_alpha/pipeline_pixart_alpha.py | 103 +++++++++++------- tests/pipelines/pixart/test_pixart.py | 78 ++++++++++++- 2 files changed, 138 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index c3f667ba16..f4e61bdc94 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -27,6 +27,7 @@ from ...models import AutoencoderKL, Transformer2DModel from ...schedulers import DPMSolverMultistepScheduler from ...utils import ( BACKENDS_MAPPING, + deprecate, is_bs4_available, is_ftfy_available, logging, @@ -162,8 +163,10 @@ class PixArtAlphaPipeline(DiffusionPipeline): device: Optional[torch.device] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, clean_caption: bool = False, - mask_feature: bool = True, + **kwargs, ): r""" Encodes the prompt into text encoder hidden states. @@ -189,10 +192,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): string. 
clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. - mask_feature: (bool, defaults to `True`): - If `True`, the function will mask the text embeddings. """ - embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) if device is None: device = self._execution_device @@ -229,13 +233,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): f" {max_length} tokens: {removed_text}" ) - attention_mask = text_inputs.attention_mask.to(device) - prompt_embeds_attention_mask = attention_mask + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) prompt_embeds = prompt_embeds[0] - else: - prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) if self.text_encoder is not None: dtype = self.text_encoder.dtype @@ -250,8 +252,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) - prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = 
prompt_attention_mask.repeat(num_images_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -267,11 +269,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): add_special_tokens=True, return_tensors="pt", ) - attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask ) negative_prompt_embeds = negative_prompt_embeds[0] @@ -284,23 +286,13 @@ class PixArtAlphaPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) else: negative_prompt_embeds = None + negative_prompt_attention_mask = None - # Perform additional masking. 
- if mask_feature and not embeds_initially_provided: - prompt_embeds = prompt_embeds.unsqueeze(1) - masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) - masked_prompt_embeds = masked_prompt_embeds.squeeze(1) - masked_negative_prompt_embeds = ( - negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None - ) - return masked_prompt_embeds, masked_negative_prompt_embeds - - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -329,6 +321,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): callback_steps, prompt_embeds=None, negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -365,6 +359,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( @@ -372,6 +372,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): @@ -579,14 +585,16 @@ class PixArtAlphaPipeline(DiffusionPipeline): generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, - mask_feature: bool = True, use_resolution_binning: bool = True, + **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -630,9 +638,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -648,11 +659,10 @@ class PixArtAlphaPipeline(DiffusionPipeline): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. - mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. - use_resolution_binning: - (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the - closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, - they are resized back to the requested resolution. Useful for generating non-square images. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. Examples: @@ -661,6 +671,9 @@ class PixArtAlphaPipeline(DiffusionPipeline): If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) # 1. Check inputs. 
Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor @@ -669,7 +682,15 @@ class PixArtAlphaPipeline(DiffusionPipeline): height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) self.check_inputs( - prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, ) # 2. Default height and width to transformer @@ -688,7 +709,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( prompt, do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -696,11 +722,13 @@ class PixArtAlphaPipeline(DiffusionPipeline): device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, - mask_feature=mask_feature, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -758,6 +786,7 @@ class PixArtAlphaPipeline(DiffusionPipeline): noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, return_dict=False, diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 1fb2560b29..b2806a5c1c 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -111,13 +111,20 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): num_inference_steps = inputs["num_inference_steps"] output_type = inputs["output_type"] - prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False) + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = pipe.encode_prompt(prompt) # inputs with prompt converted to embeddings inputs = { "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, "negative_prompt": None, "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, @@ -151,8 +158,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # inputs with prompt converted to embeddings inputs = { "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, "negative_prompt": None, "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, @@ -211,13 +220,15 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): num_inference_steps = inputs["num_inference_steps"] 
output_type = inputs["output_type"] - prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt) + prompt_embeds, prompt_attn_mask, negative_prompt_embeds, neg_prompt_attn_mask = pipe.encode_prompt(prompt) # inputs with prompt converted to embeddings inputs = { "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attn_mask, "negative_prompt": None, "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": neg_prompt_attn_mask, "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, @@ -252,8 +263,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # inputs with prompt converted to embeddings inputs = { "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attn_mask, "negative_prompt": None, "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": neg_prompt_attn_mask, "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, @@ -266,6 +279,40 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, 1e-4) + def test_inference_with_multiple_images_per_prompt(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_images_per_prompt"] = 2 + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (2, 8, 8, 3)) + expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675]) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_raises_warning_for_mask_feature(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = 
self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs.update({"mask_feature": True}) + + with self.assertWarns(FutureWarning) as warning_ctx: + _ = pipe(**inputs).images + + assert "mask_feature" in str(warning_ctx.warning) + def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) @@ -290,7 +337,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323]) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -307,7 +354,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266]) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -323,7 +370,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031]) + expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) @@ -340,7 +387,26 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332]) + expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() 
self.assertLessEqual(max_diff, 1e-3) + + def test_pixart_1024_without_resolution_binning(self): + generator = torch.manual_seed(0) + + pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + + prompt = "A small cactus with a happy face in the Sahara desert." + + image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images + image_slice = image[0, -3:, -3:, -1] + + generator = torch.manual_seed(0) + no_res_bin_image = pipe( + prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False + ).images + no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1] + + assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4)