1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (#5752)

* does this fix things?

* attention mask use

* attention mask order

* better masking.

* add: tesrt

* remove mask_featur

* test

* debug

* fix: tests

* deprecate mask_feature

* add deprecation test

* add slow test

* add print statements to retrieve the assertion values.

* fix for the 1024 fast tes

* fix tesy

* fix the remaining

* Apply suggestions from code review

* more debug

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sayak Paul
2023-11-14 17:15:35 +05:30
committed by GitHub
parent 16d500455b
commit a5720e9e31
2 changed files with 138 additions and 43 deletions

View File

@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
@@ -162,8 +163,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -189,10 +192,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
string.
clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None:
device = self._execution_device
@@ -229,13 +233,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" {max_length} tokens: {removed_text}"
)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds_attention_mask = attention_mask
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
@@ -250,8 +252,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -267,11 +269,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
add_special_tokens=True,
return_tensors="pt",
)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
@@ -284,23 +286,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
# Perform additional masking.
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
return masked_prompt_embeds, masked_negative_prompt_embeds
return prompt_embeds, negative_prompt_embeds
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -329,6 +321,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -365,6 +359,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
@@ -372,6 +372,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
@@ -579,14 +585,16 @@ class PixArtAlphaPipeline(DiffusionPipeline):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
use_resolution_binning: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -630,9 +638,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -648,11 +659,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
use_resolution_binning:
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
they are resized back to the requested resolution. Useful for generating non-square images.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
Examples:
@@ -661,6 +671,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
@@ -669,7 +682,15 @@ class PixArtAlphaPipeline(DiffusionPipeline):
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
@@ -688,7 +709,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
@@ -696,11 +722,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -758,6 +786,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,

View File

@@ -111,13 +111,20 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False)
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
@@ -151,8 +158,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
@@ -211,13 +220,15 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
prompt_embeds, prompt_attn_mask, negative_prompt_embeds, neg_prompt_attn_mask = pipe.encode_prompt(prompt)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attn_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": neg_prompt_attn_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
@@ -252,8 +263,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attn_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": neg_prompt_attn_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
@@ -266,6 +279,40 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
def test_inference_with_multiple_images_per_prompt(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_images_per_prompt"] = 2
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (2, 8, 8, 3))
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_raises_warning_for_mask_feature(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs.update({"mask_feature": True})
with self.assertWarns(FutureWarning) as warning_ctx:
_ = pipe(**inputs).images
assert "mask_feature" in str(warning_ctx.warning)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
@@ -290,7 +337,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323])
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
@@ -307,7 +354,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266])
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
@@ -323,7 +370,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031])
expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
@@ -340,7 +387,26 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332])
expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
generator = torch.manual_seed(0)
no_res_bin_image = pipe(
prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False
).images
no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]
assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4)