From dad5cb55e6ade24fb397525eb023ad4eba37019d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 12 Jan 2026 09:15:09 +0100 Subject: [PATCH 1/4] Fix QwenImage txt_seq_lens handling (#12702) * Fix QwenImage txt_seq_lens handling * formatting * formatting * remove txt_seq_lens and use bool mask * use compute_text_seq_len_from_mask * add seq_lens to dispatch_attention_fn * use joint_seq_lens * remove unused index_block * WIP: Remove seq_lens parameter and use mask-based approach - Remove seq_lens parameter from dispatch_attention_fn - Update varlen backends to extract seqlens from masks - Update QwenImage to pass 2D joint_attention_mask - Fix native backend to handle 2D boolean masks - Fix sage_varlen seqlens_q to match seqlens_k for self-attention Note: sage_varlen still producing black images, needs further investigation * fix formatting * undo sage changes * xformers support * hub fix * fix torch compile issues * fix tests * use _prepare_attn_mask_native * proper deprecation notice * add deprecate to txt_seq_lens * Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: YiYi Xu * Only create the mask if there's actual padding * fix order of docstrings * Adds performance benchmarks and optimization details for QwenImage Enhances documentation with comprehensive performance insights for QwenImage pipeline: * rope_text_seq_len = text_seq_len * rename to max_txt_seq_len * removed deprecated args * undo unrelated change * Updates QwenImage performance documentation Removes detailed attention backend benchmarks and simplifies torch.compile performance description Focuses on key performance improvement with torch.compile, highlighting the specific speedup from 4.70s to 1.93s on an A100 GPU Streamlines the documentation to provide more concise and actionable performance insights * Updates deprecation warnings for txt_seq_lens parameter Extends deprecation timeline for txt_seq_lens from version 0.37.0 to 0.39.0 across multiple Qwen image-related models Adds a new unit test to verify the deprecation warning behavior for the txt_seq_lens parameter * fix compile * formatting * fix compile tests * rename helper * remove duplicate * smaller values * removed * use torch.cond for torch compile * Construct joint attention mask once * test different backends * construct joint attention mask once to avoid reconstructing in every block * Update src/diffusers/models/attention_dispatch.py Co-authored-by: YiYi Xu * formatting * raising an error from the EditPlus pipeline when batch_size > 1 --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu Co-authored-by: cdutr --- docs/source/en/api/pipelines/qwenimage.md | 38 +++- .../train_dreambooth_lora_qwen_image.py | 2 - src/diffusers/models/attention_dispatch.py | 78 ++++++- .../controlnets/controlnet_qwenimage.py | 71 ++++-- .../transformers/transformer_qwenimage.py | 203 ++++++++++++++---- .../qwenimage/before_denoise.py | 41 ---- .../modular_pipelines/qwenimage/denoise.py | 11 +- .../pipelines/qwenimage/pipeline_qwenimage.py | 7 - .../pipeline_qwenimage_controlnet.py | 3 - .../pipeline_qwenimage_controlnet_inpaint.py | 3 - .../qwenimage/pipeline_qwenimage_edit.py | 7 - .../pipeline_qwenimage_edit_inpaint.py | 7 - .../qwenimage/pipeline_qwenimage_edit_plus.py | 14 +- .../qwenimage/pipeline_qwenimage_img2img.py | 7 - .../qwenimage/pipeline_qwenimage_inpaint.py | 7 - .../qwenimage/pipeline_qwenimage_layered.py | 8 +- .../test_models_transformer_qwenimage.py | 178 ++++++++++++++- 17 files changed, 513 insertions(+), 172 deletions(-) diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index 1dcf3f944f..ee3dd3b28e 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### torch.compile + +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). + ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 53b01bf0cf..ea9b137b0a 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,14 +1513,12 @@ def main(args): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index bdb9c4861a..9fb1e94a35 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1945,6 +1945,43 @@ def _native_flex_attention( return out +def _prepare_additive_attn_mask( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. + + This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. + + Args: + attn_mask: 2D tensor [batch_size, seq_len_k] + - Boolean: True means attend, False means mask out + - Additive: 0.0 means attend, -inf means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. + """ + # Check if the mask is boolean or already additive + if attn_mask.dtype == torch.bool: + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + else: + # Already additive mask - just ensure correct dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -1964,6 +2001,19 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Reshape 2D mask to 4D for SDPA + # SDPA accepts both boolean masks (torch.bool) and additive masks (float) + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] + # SDPA handles both boolean and additive masks correctly + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2530,10 +2580,34 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D mask to 4D for xformers + # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) + # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Convert to 4D additive mask (handles both boolean and additive inputs) + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 8697127178..fa374285ee 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import ( QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -136,7 +137,7 @@ class QwenImageControlNetModel( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -147,24 +148,39 @@ class QwenImageControlNetModel( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -186,32 +202,47 @@ class QwenImageControlNetModel( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Construct joint attention mask once to avoid reconstructing in every block + block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask + block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, + block_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.39.0. The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 1229bab169..a8c98201d9 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -142,6 +142,32 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) +def compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] +) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. + """ + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, use_additional_t_cond=False): super().__init__() @@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: List[int], - device: torch.device, + txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`): - A list of integers of length batch_size representing the length of each text prompt. - device: (`torch.device`): + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. + device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `max_txt_seq_len` instead. " + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", + standard_warn=False, + ) + if max_txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module): else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward(self, video_fhw, txt_seq_lens, device): + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: - txt_length: [bs] a list of 1 integers representing the length of the text + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width) - video_freq = video_freq.to(device) + video_freq = self._compute_condition_freqs(frame, height, width, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module): max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0): + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module): return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width): + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0: joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, joint_key, @@ -762,14 +848,25 @@ class QwenImageTransformer2DModel( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -778,6 +875,15 @@ class QwenImageTransformer2DModel( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `encoder_hidden_states_mask` instead. " + "The mask-based approach is more flexible and supports variable-length sequences.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -810,6 +916,11 @@ class QwenImageTransformer2DModel( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -819,7 +930,17 @@ class QwenImageTransformer2DModel( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + + # Construct joint attention mask once to avoid reconstructing in every block + # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -827,10 +948,10 @@ class QwenImageTransformer2DModel( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, - attention_kwargs, + block_attention_kwargs, modulate_index, ) @@ -838,10 +959,10 @@ class QwenImageTransformer2DModel( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, modulate_index=modulate_index, ) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index e14164229c..d9c8cbb01d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): ) ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) self.set_block_state(state, block_state) @@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index eb1e5a341c..d6bcb4a94f 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + "It should contain prompt_embeds/negative_prompt_embeds." ), ), ] @@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks): getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) @@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks): getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b9..bc3ce84e10 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93..ce6fc974a5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab5..77d78a5ca7 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8..dd723460a5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa..cf467203a9 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf16..257e2d846c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): else: batch_size = prompt_embeds.shape[0] + # QwenImageEditPlusPipeline does not currently support batch_size > 1 + if batch_size > 1: + raise ValueError( + f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " + "Please process prompts one at a time." + ) + device = self._execution_device # 3. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): @@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016..e0b41b8b87 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2..83f02539b1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 7bb12c26ba..53d2c169ee 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. Denoising loop self.scheduler.set_begin_index(0) @@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as latents = latents[:, :, 1:] # remove the first frame as it is the orgin input - latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503..384954dfba 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) + + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) + + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask + with torch.no_grad(): + output = model(**inputs) + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(int(per_sample_len.max().item()), 5) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 3, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) + "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + additional_t_cond=addition_t_cond, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() From 29a930a142bba88ce30eba8e93e512a3e9bdc49c Mon Sep 17 00:00:00 2001 From: Leo Jiang Date: Mon, 12 Jan 2026 07:37:02 -0700 Subject: [PATCH 2/4] Bugfix for flux2 img2img2 prediction (#12855) * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 * Bugfix for dreambooth flux2 img2img2 Co-authored-by: tcaimm <93749364+tcaimm@users.noreply.github.com> --------- Co-authored-by: tcaimm <93749364+tcaimm@users.noreply.github.com> --- .../train_dreambooth_lora_flux2_img2img.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 419821e8a8..5af0906642 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1695,9 +1695,13 @@ def main(args): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1724,6 +1728,9 @@ def main(args): packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) @@ -1742,7 +1749,8 @@ def main(args): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) From f1a93c765f4b292352e26fb10670373b8e5837f7 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:01:58 -0800 Subject: [PATCH 3/4] Add Flag to `PeftLoraLoaderMixinTests` to Enable/Disable Text Encoder LoRA Tests (#12962) * Improve incorrect LoRA format error message * Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests * Apply changes to LTX2LoraTests * Further improve incorrect LoRA format error msg following review --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/lora_pipeline.py | 40 ++++++++++----------- tests/lora/test_lora_layers_auraflow.py | 22 ++---------- tests/lora/test_lora_layers_cogvideox.py | 22 ++---------- tests/lora/test_lora_layers_cogview4.py | 22 ++---------- tests/lora/test_lora_layers_flux2.py | 22 ++---------- tests/lora/test_lora_layers_hunyuanvideo.py | 22 ++---------- tests/lora/test_lora_layers_ltx2.py | 26 ++------------ tests/lora/test_lora_layers_ltx_video.py | 22 ++---------- tests/lora/test_lora_layers_lumina2.py | 22 ++---------- tests/lora/test_lora_layers_mochi.py | 22 ++---------- tests/lora/test_lora_layers_qwenimage.py | 22 ++---------- tests/lora/test_lora_layers_sana.py | 22 ++---------- tests/lora/test_lora_layers_wan.py | 22 ++---------- tests/lora/test_lora_layers_wanvace.py | 22 ++---------- tests/lora/test_lora_layers_z_image.py | 22 ++---------- tests/lora/utils.py | 19 ++++++++++ 16 files changed, 67 insertions(+), 304 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5fc650a80d..24d1fd7b93 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_peft_state_dict = { k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") @@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56..78ef4ce151 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2..7bd54b77ca 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb63..e8ee6e7a7d 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 4ae189aceb..d970b7d784 100644 --- a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Flux2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a..e59bc5662f 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder_2", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py index 886ae70b7d..0a4b14454f 100644 --- a/tests/lora/test_lora_layers_ltx2.py +++ b/tests/lora/test_lora_layers_ltx2.py @@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 5, 32, 32, 3) @@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTX2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_save_pretrained_with_text_lora(self): - pass diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e51..095e5b577c 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33..da032229a7 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 4, 4, 3) @@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db..ee82541129 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 7, 16, 16, 3) @@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20..73fd026a67 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ) denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index a860b7b44f..97bf5cbba9 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference_denoiser(self): return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b41..5ae16ab4b9 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9..c8acaea9be 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index 35d1389d96..8432ea56a6 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in ZImage.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5fae6cac0a..efa49b9f48 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests: tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + supports_text_encoder_loras = True unet_kwargs = None transformer_cls = None @@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA. """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests: with different ranks and some adapters removed and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, _, _ = self.get_dummy_components() # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) From 9d68742214b9491718487b63f25d6e8236ed56ea Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri <66717082+Bissmella@users.noreply.github.com> Date: Tue, 13 Jan 2026 04:46:51 +0100 Subject: [PATCH 4/4] Add Unified Sequence Parallel attention (#12693) * initial scheme of unified-sp * initial all_to_all_double * bug fixes, added cmnts * unified attention prototype done * remove raising value error in contextParallelConfig to enable unified attention * bug fix * feat: Adds Test for Unified SP Attention and Fixes a bug in Template Ring Attention * bug fix, lse calculation, testing bug fixes, lse calculation - switched to _all_to_all_single helper in _all_to_all_dim_exchange due contiguity issues bug fix bug fix bug fix * addressing comments * sequence parallelsim bug fixes * code format fixes * Apply style fixes * code formatting fix * added unified attention docs and removed test file * Apply style fixes * tip for unified attention in docs at distributed_inference.md Co-authored-by: Sayak Paul * Update distributed_inference.md, adding benchmarks Co-authored-by: Sayak Paul * Update docs/source/en/training/distributed_inference.md Co-authored-by: Sayak Paul * function name fix * fixed benchmark in docs --------- Co-authored-by: KarthikSundar2002 Co-authored-by: github-actions[bot] Co-authored-by: Sayak Paul --- .../en/training/distributed_inference.md | 28 +++ src/diffusers/models/_modeling_parallel.py | 4 - src/diffusers/models/attention_dispatch.py | 188 +++++++++++++++++- 3 files changed, 212 insertions(+), 8 deletions(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index b7cd0e20f4..b7255f74af 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained( CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, ).to(device) ``` +### Unified Attention + +[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout. + +This hybrid approach leverages the strengths of both methods: +- **Ulysses Attention** efficiently parallelizes across attention heads +- **Ring Attention** handles very long sequences with minimal memory overhead +- Together, they enable 2D parallelization across both heads and sequence dimensions + +[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping). +Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`]. + +```py +pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)) +``` + +> [!TIP] +> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices). + +We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows: + +| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | +|--------------------|------------------|-------------|------------------| +| ulysses | 6670.789 | 7.50 | 33.85 | +| ring | 13076.492 | 3.82 | 56.02 | +| unified_balanced | 11068.705 | 4.52 | 33.85 | + +From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to number of attention-heads, a limitation that is solved by unified attention. diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 2a4eb520c7..1c7703a13c 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -90,10 +90,6 @@ class ContextParallelConfig: ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if self.ring_degree > 1 and self.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 9fb1e94a35..f4ec497038 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x +def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + """ + Perform dimension sharding / reassembly across processes using _all_to_all_single. + + This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or + head dimension flexibly by accepting scatter_idx and gather_idx. + + Args: + x (torch.Tensor): + Input tensor. Expected shapes: + - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim) + - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim) + scatter_idx (int) : + Dimension along which the tensor is partitioned before all-to-all. + gather_idx (int): + Dimension along which the output is reassembled after all-to-all. + group : + Distributed process group for the Ulysses group. + + Returns: + torch.Tensor: Tensor with globally exchanged dimensions. + - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim) + - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim) + """ + group_world_size = torch.distributed.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence + # dimension and scatters head dimension + batch_size, seq_len_local, num_heads, head_dim = x.shape + seq_len = seq_len_local * group_world_size + num_heads_local = num_heads // group_world_size + + # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D + x_temp = ( + x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim) + .transpose(0, 2) + .contiguous() + ) + + if group_world_size > 1: + out = _all_to_all_single(x_temp, group=group) + else: + out = x_temp + # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D + out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous() + out = out.reshape(batch_size, seq_len, num_heads_local, head_dim) + return out + elif scatter_idx == 1 and gather_idx == 2: + # Used after ulysses sequence parallel in unified SP. gathers the head dimension + # scatters back the sequence dimension. + batch_size, seq_len, num_heads_local, head_dim = x.shape + num_heads = num_heads_local * group_world_size + seq_len_local = seq_len // group_world_size + + # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = ( + x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim) + .permute(1, 3, 2, 0, 4) + .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) + ) + + if group_world_size > 1: + output = _all_to_all_single(x_temp, group) + else: + output = x_temp + output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous() + output = output.reshape(batch_size, seq_len_local, num_heads, head_dim) + return output + else: + raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.") + + +class SeqAllToAllDim(torch.autograd.Function): + """ + all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange + for more info. + """ + + @staticmethod + def forward(ctx, group, input, scatter_id=2, gather_id=1): + ctx.group = group + ctx.scatter_id = scatter_id + ctx.gather_id = gather_id + return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) + + @staticmethod + def backward(ctx, grad_outputs): + grad_input = SeqAllToAllDim.apply( + ctx.group, + grad_outputs, + ctx.gather_id, # reversed + ctx.scatter_id, # reversed + ) + return (None, grad_input, None, None) + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function): out = out.to(torch.float32) lse = lse.to(torch.float32) - lse = lse.unsqueeze(-1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) @@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function): grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): @@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function): x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None + + +def _templated_unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + scatter_idx: int = 2, + gather_idx: int = 1, +): + """ + Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719 + """ + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + # context_layer is of shape (B, S, H_LOCAL, D) + output = SeqAllToAllDim.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) + if return_lse: + # lse is of shape (B, S, H_LOCAL, 1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) + lse = lse.squeeze(-1) + return (output, lse) + return output def _templated_context_parallel_attention( @@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention( raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1: + if ( + _parallel_config.context_parallel_config.ring_degree > 1 + and _parallel_config.context_parallel_config.ulysses_degree > 1 + ): + return _templated_unified_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( query, key,