From ef913010d4a337e1696cf50ff0cbb008e79c6b67 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 28 Jan 2026 15:44:12 +0530
Subject: [PATCH] [QwenImage] fix prompt isolation tests (#13042)

* up

* up

* up

* fix
---
 .../train_dreambooth_lora_qwen_image.py       |  3 ++-
 .../pipelines/qwenimage/pipeline_qwenimage.py | 22 ++++++-----------
 .../pipeline_qwenimage_controlnet.py          | 19 +++++----------
 .../pipeline_qwenimage_edit_inpaint.py        |  8 -------
 .../qwenimage/pipeline_qwenimage_img2img.py   | 24 +++++++------------
 .../qwenimage/pipeline_qwenimage_inpaint.py   | 23 +++++++-----------
 .../qwenimage/pipeline_qwenimage_layered.py   | 13 +++++-----
 7 files changed, 38 insertions(+), 74 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
index ea9b137b0a..33a1054eff 100644
--- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py
+++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1467,7 +1467,8 @@ def main(args):
                 else:
                     num_repeat_elements = len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
-                    prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
+                    if prompt_embeds_mask is not None:
+                        prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 21515df608..9938ae95fd 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -254,16 +254,17 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
         prompt_embeds = prompt_embeds[:, :max_sequence_length]
-        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
-
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None
 
         return prompt_embeds, prompt_embeds_mask
 
@@ -310,15 +311,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
index 714ec04238..e0cb9924cd 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -321,11 +321,13 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None
 
         return prompt_embeds, prompt_embeds_mask
 
@@ -372,15 +374,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
index 75be62cf6d..bc397f357b 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
@@ -378,14 +378,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
         if padding_mask_crop is not None:
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
index 2c9da7545e..522e1d2030 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -265,7 +265,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -297,16 +297,17 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
         prompt_embeds = prompt_embeds[:, :max_sequence_length]
-        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
-
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None
 
         return prompt_embeds, prompt_embeds_mask
 
@@ -357,15 +358,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
index 536a7984e8..960196bec1 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -276,7 +276,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -308,16 +308,17 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
         prompt_embeds = prompt_embeds[:, :max_sequence_length]
-        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
-
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None
 
         return prompt_embeds, prompt_embeds_mask
 
@@ -372,14 +373,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
 
-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
- ) if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 0a53a0ac77..ea455721e4 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -320,16 +320,17 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) - if prompt_embeds_mask is not None and prompt_embeds_mask.all(): - prompt_embeds_mask = None + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask
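For reviewers, the caller-visible contract after this patch: `encode_prompt` may now return `prompt_embeds_mask` as `None` (the mask is dropped whenever it is all ones), and `check_inputs` no longer rejects `prompt_embeds` passed without a mask. Below is a minimal sketch of the guard downstream code now needs, mirroring the training-script hunk above; the checkpoint id and prompt are illustrative, not taken from this patch:

```python
import torch
from diffusers import QwenImagePipeline

# Illustrative checkpoint id; any QwenImage checkpoint behaves the same way here.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# After this patch the mask may come back as None when every token is attended,
# so callers must not assume it is a tensor.
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
    prompt="a photo of a corgi",
    num_images_per_prompt=1,
)

# Only repeat the mask when it exists, exactly as the
# train_dreambooth_lora_qwen_image.py change does.
num_repeat_elements = 4
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
if prompt_embeds_mask is not None:
    prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
```

Passing the returned `prompt_embeds` back into a pipeline call without a mask is also legal now, which is what the removed `check_inputs` branches used to forbid.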