
[Qwen] avoid creating attention masks when there is no padding (#12987)

* avoid creating attention masks when there is no padding

* make fix-copies

* torch compile tests

* set all-ones mask to None

* prevent the positional encoding from exceeding 4096

* fix from review

* slice freqs_cis to match the input sequence length

* keep only the attention masking change

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Kashif Rasul
Date: 2026-01-27 23:42:48 +01:00
Committed by: GitHub
Parent: 22ac6fae24
Commit: d54669a73e
10 changed files with 98 additions and 0 deletions
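
The change is safe because an attention mask that keeps every token carries no information: passing no mask at all gives the same attention output, while sparing the model from building and broadcasting a mask tensor it does not need. Below is a minimal sanity check of that equivalence, using plain torch.nn.functional.scaled_dot_product_attention on random tensors rather than the pipelines' own attention code.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim) query/key/value with no padding tokens.
q, k, v = (torch.randn(1, 4, 7, 16) for _ in range(3))

# Boolean mask where True means "attend to this key"; all True == no padding.
all_ones = torch.ones(1, 1, 1, 7, dtype=torch.bool)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=all_ones)
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# Identical results: an all-ones mask is a no-op, so it can be dropped.
assert torch.allclose(out_masked, out_unmasked, atol=1e-6)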


@@ -262,6 +262,9 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
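
The same two-line guard is repeated verbatim (via make fix-copies) in each of the pipelines below. Downstream, the benefit shows up wherever the mask is consumed: when prompt_embeds_mask arrives as None, attention runs un-masked instead of constructing a broadcastable mask every step. The sketch below illustrates that consumer-side pattern in isolation; attention_with_optional_mask and key_keep_mask are illustrative names, not the diffusers attention processor.

import torch
import torch.nn.functional as F

def attention_with_optional_mask(q, k, v, key_keep_mask=None):
    # q, k, v: (batch, heads, seq_len, head_dim); key_keep_mask: optional (batch, seq_len)
    # bool tensor where True marks a real (non-padding) key token.
    attn_mask = None
    if key_keep_mask is not None:
        # Only materialize an attention mask when some keys are actually padding.
        attn_mask = key_keep_mask[:, None, None, :]  # broadcast over heads and queries
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# With no padding the caller passes None, the masked branch is skipped, and no
# mask tensor is ever built.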


@@ -324,6 +324,9 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(


@@ -305,6 +305,9 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(


@@ -309,6 +309,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(


@@ -321,6 +321,9 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs


@@ -323,6 +323,9 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs


@@ -305,6 +305,9 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(


@@ -316,6 +316,9 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(


@@ -328,6 +328,9 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):


@@ -276,3 +276,74 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
     def test_torch_compile_recompilation_and_graph_break(self):
         super().test_torch_compile_recompilation_and_graph_break()
 
+    def test_torch_compile_with_and_without_mask(self):
+        """Test that torch.compile works with both None mask and padding mask."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+        model.compile(mode="default", fullgraph=True)
+
+        # Test 1: Run with None mask (no padding, all tokens are valid)
+        inputs_no_mask = inputs.copy()
+        inputs_no_mask["encoder_hidden_states_mask"] = None
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_no_mask = model(**inputs_no_mask)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_no_mask_2 = model(**inputs_no_mask)
+
+        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 2: Run with all-ones mask (should behave like None)
+        inputs_all_ones = inputs.copy()
+        # Keep the all-ones mask
+        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_all_ones = model(**inputs_all_ones)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_all_ones_2 = model(**inputs_all_ones)
+
+        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 3: Run with actual padding mask (has zeros)
+        inputs_with_padding = inputs.copy()
+        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
+        mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding
+        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_with_padding = model(**inputs_with_padding)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_with_padding_2 = model(**inputs_with_padding)
+
+        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Verify that outputs are different (mask should affect results)
+        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
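
The recompilation check that the new test repeats three times can be read in isolation. Below is a minimal standalone sketch of the same pattern; run_twice_without_recompile is an illustrative helper, not part of the diff. The idea: run once to trigger compilation, then run again under error_on_recompile=True so any guard miss, such as a mask flipping between None and a tensor, raises instead of silently recompiling.

import torch

def run_twice_without_recompile(model, **inputs):
    # First call: trigger (or reuse) compilation of the model's graph.
    with torch.no_grad():
        first = model(**inputs)
    # Second call: any recompilation now raises, proving the compiled
    # graph is reused for these exact inputs.
    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        second = model(**inputs)
    return first, second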