mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix QwenImage txt_seq_lens handling (#12702)
* Fix QwenImage txt_seq_lens handling
* formatting
* formatting
* remove txt_seq_lens and use bool mask
* use compute_text_seq_len_from_mask
* add seq_lens to dispatch_attention_fn
* use joint_seq_lens
* remove unused index_block
* WIP: Remove seq_lens parameter and use mask-based approach
  - Remove seq_lens parameter from dispatch_attention_fn
  - Update varlen backends to extract seqlens from masks
  - Update QwenImage to pass 2D joint_attention_mask
  - Fix native backend to handle 2D boolean masks
  - Fix sage_varlen seqlens_q to match seqlens_k for self-attention
  Note: sage_varlen still producing black images, needs further investigation
* fix formatting
* undo sage changes
* xformers support
* hub fix
* fix torch compile issues
* fix tests
* use _prepare_attn_mask_native
* proper deprecation notice
* add deprecate to txt_seq_lens
* Update src/diffusers/models/transformers/transformer_qwenimage.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/transformers/transformer_qwenimage.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Only create the mask if there's actual padding
* fix order of docstrings
* Adds performance benchmarks and optimization details for QwenImage
  Enhances documentation with comprehensive performance insights for QwenImage pipeline
* rope_text_seq_len = text_seq_len
* rename to max_txt_seq_len
* removed deprecated args
* undo unrelated change
* Updates QwenImage performance documentation
  Removes detailed attention backend benchmarks and simplifies the torch.compile performance description.
  Focuses on the key performance improvement with torch.compile, highlighting the specific speedup from 4.70s to 1.93s on an A100 GPU.
  Streamlines the documentation to provide more concise and actionable performance insights.
* Updates deprecation warnings for txt_seq_lens parameter
  Extends the deprecation timeline for txt_seq_lens from version 0.37.0 to 0.39.0 across multiple Qwen image-related models.
  Adds a new unit test to verify the deprecation warning behavior for the txt_seq_lens parameter.
* fix compile
* formatting
* fix compile tests
* rename helper
* remove duplicate
* smaller values
* removed
* use torch.cond for torch compile
* Construct joint attention mask once
* test different backends
* construct joint attention mask once to avoid reconstructing in every block
* Update src/diffusers/models/attention_dispatch.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* formatting
* raising an error from the EditPlus pipeline when batch_size > 1

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: cdutr <dutra_carlos@hotmail.com>
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
    image=[image_1, image_2],
    prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
    num_inference_steps=50,
).images[0]
```

## Performance

### torch.compile

Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# The first call triggers compilation (~7s overhead);
# subsequent calls run ~2.4x faster.
image = pipe("a cat", num_inference_steps=50).images[0]
```

### Batched Inference with Variable-Length Prompts

When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline handles padding through attention masking, ensuring that padding tokens do not influence the generated output.

```python
# CFG with different prompt lengths works correctly
image = pipe(
    prompt="A cat",
    negative_prompt="blurry, low quality, distorted",
    true_cfg_scale=3.5,
    num_inference_steps=50,
).images[0]
```

For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).

## QwenImagePipeline

[[autodoc]] QwenImagePipeline
@@ -1513,14 +1513,12 @@ def main(args):
    height=model_input.shape[3],
    width=model_input.shape[4],
)
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
    hidden_states=packed_noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    encoder_hidden_states_mask=prompt_embeds_mask,
    timestep=timesteps / 1000,
    img_shapes=img_shapes,
    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
    return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(
@@ -1945,6 +1945,43 @@ def _native_flex_attention(
    return out


def _prepare_additive_attn_mask(
    attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
    """
    Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.

    This helper is used by both the native SDPA and xformers backends to handle both boolean and additive masks.

    Args:
        attn_mask: 2D tensor of shape [batch_size, seq_len_k]
            - Boolean: True means attend, False means mask out
            - Additive: 0.0 means attend, -inf means mask out
        target_dtype: The dtype to convert the mask to (usually query.dtype)
        reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting

    Returns:
        Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
        reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
    """
    # Check if the mask is boolean or already additive
    if attn_mask.dtype == torch.bool:
        # Convert boolean to additive: True -> 0.0, False -> -inf
        attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
        # Convert to the target dtype
        attn_mask = attn_mask.to(dtype=target_dtype)
    else:
        # Already an additive mask - just ensure the correct dtype
        attn_mask = attn_mask.to(dtype=target_dtype)

    # Optionally reshape to 4D for broadcasting in attention mechanisms
    if reshape_4d:
        batch_size, seq_len_k = attn_mask.shape
        attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)

    return attn_mask

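For illustration, a minimal standalone sketch of the boolean-to-additive conversion this helper performs (plain PyTorch, toy values; not the library function itself):

```python
import torch

# Boolean key-padding mask: True = attend, False = mask out.
bool_mask = torch.tensor([[True, True, False]])  # [batch_size=1, seq_len_k=3]

# Same transformation as above: True -> 0.0, False -> -inf, then reshape for SDPA broadcasting.
additive = torch.where(bool_mask, 0.0, float("-inf")).to(torch.float32)
additive = additive.view(1, 1, 1, 3)
print(additive)  # tensor([[[[0., 0., -inf]]]])
```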
@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],

@@ -1964,6 +2001,19 @@ def _native_attention(
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native attention backend does not support setting `return_lse=True`.")

    # Reshape 2D mask to 4D for SDPA
    # SDPA accepts both boolean masks (torch.bool) and additive masks (float)
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
        # SDPA handles both boolean and additive masks correctly
        attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)

    if _parallel_config is None:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
@@ -2530,10 +2580,34 @@ def _xformers_attention(
|
||||
attn_mask = xops.LowerTriangularMask()
|
||||
elif attn_mask is not None:
|
||||
if attn_mask.ndim == 2:
|
||||
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
|
||||
# Convert 2D mask to 4D for xformers
|
||||
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
|
||||
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
|
||||
# Need memory alignment - create larger tensor and slice for alignment
|
||||
original_seq_len = attn_mask.size(1)
|
||||
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
|
||||
|
||||
# Create aligned 4D tensor and slice to ensure proper memory layout
|
||||
aligned_mask = torch.zeros(
|
||||
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
# Convert to 4D additive mask (handles both boolean and additive inputs)
|
||||
mask_additive = _prepare_additive_attn_mask(
|
||||
attn_mask, target_dtype=query.dtype
|
||||
) # [batch, 1, 1, seq_len_k]
|
||||
# Broadcast to [batch, heads, seq_q, seq_len_k]
|
||||
aligned_mask[:, :, :, :original_seq_len] = mask_additive
|
||||
# Fill the alignment padding columns with -inf so they can never be attended to
|
||||
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
|
||||
|
||||
# Slice to actual size with proper alignment
|
||||
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
|
||||
elif attn_mask.ndim != 4:
|
||||
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
elif attn_mask.ndim == 4:
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
|
||||
if enable_gqa:
|
||||
if num_heads_q % num_heads_kv != 0:
|
||||
|
||||
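A quick numeric check of the alignment rule used in the xformers path above (standalone sketch, toy values): key sequence lengths are rounded up to the next multiple of 8 before the mask buffer is allocated, and the extra columns are then masked out.

```python
# Round a key sequence length up to the next multiple of 8, as done above.
for original_seq_len in (16, 17, 21, 24):
    aligned_seq_len = ((original_seq_len + 7) // 8) * 8
    print(original_seq_len, "->", aligned_seq_len)  # 16->16, 17->24, 21->24, 24->24
```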
@@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..controlnets.controlnet import zero_module
|
||||
@@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import (
|
||||
QwenImageTransformerBlock,
|
||||
QwenTimestepProjEmbeddings,
|
||||
RMSNorm,
|
||||
compute_text_seq_len_from_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -136,7 +137,7 @@ class QwenImageControlNetModel(
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
The [`QwenImageControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
@@ -147,24 +148,39 @@ class QwenImageControlNetModel(
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
|
||||
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
|
||||
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
|
||||
Image shapes for RoPE computation.
|
||||
txt_seq_lens (`List[int]`, *optional*):
|
||||
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
|
||||
length.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
|
||||
the first element is the controlnet block samples.
|
||||
"""
|
||||
# Handle deprecated txt_seq_lens parameter
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
|
||||
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
|
||||
"and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
@@ -186,32 +202,47 @@ class QwenImageControlNetModel(
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
|
||||
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Construct joint attention mask once to avoid reconstructing in every block
|
||||
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
# Build joint mask: [text_mask, all_ones_for_image]
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
block_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=block_attention_kwargs,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
@@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[QwenImageControlNetOutput, Tuple]:
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
|
||||
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
|
||||
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
@@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
timestep=timestep,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
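To illustrate the migration implied by the deprecation above (a hedged sketch; the mask values are toy examples): callers that previously derived `txt_seq_lens` from the prompt mask can now pass the mask itself and let the model infer the lengths.

```python
import torch

# Toy 0/1 prompt mask of shape [batch_size, text_seq_len]; the last three tokens are padding.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0]])

# Before (deprecated): compute explicit per-sample lengths for the model.
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # [4]

# After: pass the mask directly; txt_seq_lens is no longer needed.
# controlnet(..., encoder_hidden_states_mask=prompt_embeds_mask)
```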
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
@@ -142,6 +142,32 @@ def apply_rotary_emb_qwen(
    return x_out.type_as(x)


def compute_text_seq_len_from_mask(
    encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Compute the text sequence length without assuming a contiguous mask. Returns the sequence length to use for
    RoPE, the per-sample lengths, and a normalized boolean mask.
    """
    batch_size, text_seq_len = encoder_hidden_states.shape[:2]
    if encoder_hidden_states_mask is None:
        return text_seq_len, None, None

    if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
        raise ValueError(
            f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
            f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
        )

    if encoder_hidden_states_mask.dtype != torch.bool:
        encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)

    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
    has_active = encoder_hidden_states_mask.any(dim=1)
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
    return text_seq_len, per_sample_len, encoder_hidden_states_mask

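For intuition, a standalone sketch of the per-sample length logic above (plain PyTorch, toy mask): the length is the index of the last active token plus one, which also handles non-contiguous masks.

```python
import torch

mask = torch.tensor([[True, False, True, False, False]])  # last active position is index 2
position_ids = torch.arange(mask.shape[1])
active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
per_sample_len = active_positions.max(dim=1).values + 1
print(per_sample_len)  # tensor([3])
```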
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
@@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
|
||||
txt_seq_lens: List[int],
|
||||
device: torch.device,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
device: torch.device = None,
|
||||
max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video.
|
||||
txt_seq_lens (`List[int]`):
|
||||
A list of integers of length batch_size representing the length of each text prompt.
|
||||
device: (`torch.device`):
|
||||
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
|
||||
Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
|
||||
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
|
||||
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
# Handle deprecated txt_seq_lens parameter
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
|
||||
"Please use `max_txt_seq_len` instead. "
|
||||
"The new parameter accepts a single int or tensor value representing the maximum text sequence length.",
|
||||
standard_warn=False,
|
||||
)
|
||||
if max_txt_seq_len is None:
|
||||
# Use max of txt_seq_lens for backward compatibility
|
||||
max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens
|
||||
|
||||
if max_txt_seq_len is None:
|
||||
raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
|
||||
|
||||
# Validate batch inference with variable-sized images
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
# Check if all instances have the same size
|
||||
first_fhw = video_fhw[0]
|
||||
if not all(fhw == first_fhw for fhw in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in QwenEmbedRope. "
|
||||
"All images in the batch should have the same dimensions (frame, height, width). "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
@@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module):
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = video_freq.to(device)
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
@@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module):
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
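Because `_compute_video_freqs` is wrapped in `functools.lru_cache`, the `device` argument added above becomes part of the cache key, so the frequencies are computed once per (shape, device) combination and reused afterwards. A standalone sketch of that caching behaviour (hypothetical function, not the model method):

```python
import functools

import torch


@functools.lru_cache(maxsize=128)
def compute_freqs(frame: int, height: int, width: int, device: torch.device = None):
    print("cache miss for", (frame, height, width, device))  # only printed when not cached
    return torch.zeros(frame * height * width, device=device)


compute_freqs(1, 4, 4, torch.device("cpu"))  # cache miss, computes
compute_freqs(1, 4, 4, torch.device("cpu"))  # served from the cache, no print
```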
@@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
|
||||
max_txt_seq_len: Union[int, torch.Tensor],
|
||||
device: torch.device = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
Args:
|
||||
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
|
||||
structures.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`):
|
||||
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
|
||||
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
# Validate batch inference with variable-sized images
|
||||
# In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
# Check if this is batch inference (list of layer lists/tuples)
|
||||
first_entry = video_fhw[0]
|
||||
if not all(entry == first_entry for entry in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. "
|
||||
"All images in the batch should have the same layer structure. "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
@@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
if idx != layer_num:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
else:
|
||||
### For the condition image, we set the layer index to -1
|
||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||
video_freq = video_freq.to(device)
|
||||
video_freq = self._compute_condition_freqs(frame, height, width, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
@@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
@@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width):
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
@@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
# Compute joint attention
|
||||
joint_hidden_states = dispatch_attention_fn(
|
||||
joint_query,
|
||||
joint_key,
|
||||
@@ -762,14 +848,25 @@ class QwenImageTransformer2DModel(
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
||||
Mask of the input conditions.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
|
||||
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
|
||||
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
|
||||
Image shapes for RoPE computation.
|
||||
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
|
||||
Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
|
||||
used to compute RoPE sequence length.
|
||||
guidance (`torch.Tensor`, *optional*):
|
||||
Guidance tensor for conditional generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_block_samples (*optional*):
|
||||
ControlNet block samples to add to the transformer blocks.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
@@ -778,6 +875,15 @@ class QwenImageTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
|
||||
"Please use `encoder_hidden_states_mask` instead. "
|
||||
"The mask-based approach is more flexible and supports variable-length sequences.",
|
||||
standard_warn=False,
|
||||
)
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
@@ -810,6 +916,11 @@ class QwenImageTransformer2DModel(
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
|
||||
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
@@ -819,7 +930,17 @@ class QwenImageTransformer2DModel(
|
||||
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
# Construct joint attention mask once to avoid reconstructing in every block
|
||||
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
|
||||
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
# Build joint mask: [text_mask, all_ones_for_image]
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -827,10 +948,10 @@ class QwenImageTransformer2DModel(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_kwargs,
|
||||
block_attention_kwargs,
|
||||
modulate_index,
|
||||
)
|
||||
|
||||
@@ -838,10 +959,10 @@ class QwenImageTransformer2DModel(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=attention_kwargs,
|
||||
joint_attention_kwargs=block_attention_kwargs,
|
||||
modulate_index=modulate_index,
|
||||
)
|
||||
|
||||
|
||||
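As a shape sketch of the joint mask construction above (assumed toy sizes, plain PyTorch): the boolean text mask of shape [batch_size, text_seq_len] is concatenated with an all-ones image mask of shape [batch_size, image_seq_len], and the resulting [batch_size, text_seq_len + image_seq_len] mask is built once and reused by every transformer block.

```python
import torch

batch_size, text_seq_len, image_seq_len = 2, 7, 16  # assumed toy sizes
text_mask = torch.ones(batch_size, text_seq_len, dtype=torch.bool)
text_mask[:, 5:] = False  # the last two text tokens are padding
image_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)

joint_attention_mask = torch.cat([text_mask, image_mask], dim=1)
print(joint_attention_mask.shape)  # torch.Size([2, 23])
```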
@@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
)
|
||||
]
|
||||
] * block_state.batch_size
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
@@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
]
|
||||
] * block_state.batch_size
|
||||
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs for the denoiser. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
|
||||
"It should contain prompt_embeds/negative_prompt_embeds."
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
img_shapes=block_state.img_shapes,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
|
||||
txt_seq_lens=block_state.txt_seq_lens,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
@@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
@@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
getattr(block_state, "prompt_embeds_mask", None),
|
||||
getattr(block_state, "negative_prompt_embeds_mask", None),
|
||||
),
|
||||
"txt_seq_lens": (
|
||||
getattr(block_state, "txt_seq_lens", None),
|
||||
getattr(block_state, "negative_txt_seq_lens", None),
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
|
||||
@@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
@@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
@@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
|
||||
@@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
@@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
@@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
|
||||
@@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# QwenImageEditPlusPipeline does not currently support batch_size > 1
|
||||
if batch_size > 1:
|
||||
raise ValueError(
|
||||
f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
|
||||
"Please process prompts one at a time."
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
# 3. Preprocess image
|
||||
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||
@@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
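Given the batch_size=1 restriction added to `QwenImageEditPlusPipeline` above, callers with several prompts can simply loop over them; a hedged usage sketch reusing `pipe`, `image_1`, and `image_2` from the documentation example earlier in this diff (the prompt strings are toy examples):

```python
prompts = [
    'put the penguin and the cat at a game show called "Qwen Edit Plus Games"',
    "put the penguin and the cat on a snowy mountain",
]
images = [
    pipe(image=[image_1, image_2], prompt=p, num_inference_steps=50).images[0]
    for p in prompts
]
```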
@@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
@@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
@@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
additional_t_cond=is_rgb,
|
||||
return_dict=False,
|
||||
@@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
additional_t_cond=is_rgb,
|
||||
return_dict=False,
|
||||
@@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
|
||||
|
||||
latents = latents[:, :, 1:]  # remove the first frame as it is the origin input
|
||||
|
||||
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w
|
||||
|
||||
|
||||
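The `view` → `reshape` change above matters because `permute` returns a non-contiguous tensor: `view` requires compatible strides and raises an error, while `reshape` copies when needed. A minimal standalone reproduction (toy shapes):

```python
import torch

x = torch.randn(2, 3, 4, 5, 6).permute(0, 2, 1, 3, 4)  # non-contiguous after permute
try:
    x.view(-1, 3, 1, 5, 6)
except RuntimeError as err:
    print("view failed:", err)
print(x.reshape(-1, 3, 1, 5, 6).shape)  # reshape works: torch.Size([8, 3, 1, 5, 6])
```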
@@ -15,10 +15,10 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
@@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
"txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
|
||||
|
||||
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
|
||||
|
||||
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
|
||||
)
|
||||
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
# Pattern: [True, False, True, False, True, False, False]
|
||||
encoder_hidden_states_mask[:, 1] = 0
|
||||
encoder_hidden_states_mask[:, 3] = 0
|
||||
encoder_hidden_states_mask[:, 5:] = 0
|
||||
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
|
||||
|
||||
# Remove encoder_hidden_states_mask to use the deprecated path
|
||||
inputs_with_deprecated = inputs.copy()
|
||||
inputs_with_deprecated.pop("encoder_hidden_states_mask")
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
# Test that deprecation warning is raised
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
# Create layered model config
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
"use_additional_t_cond": True, # Enable additional time conditioning
|
||||
}
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 7
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
# For layered model: (layers + 1) because we have N layers + 1 combined image
|
||||
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
|
||||
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
|
||||
|
||||
# Create mask with some padding
|
||||
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
|
||||
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device)
|
||||
|
||||
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
|
||||
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
|
||||
|
||||
# Layer structure: 4 layers + 1 condition image
|
||||
img_shapes = [
|
||||
[
|
||||
(1, img_h, img_w), # layer 0
|
||||
(1, img_h, img_w), # layer 1
|
||||
(1, img_h, img_w), # layer 2
|
||||
(1, img_h, img_w), # layer 3
|
||||
(1, img_h, img_w), # condition image (last one gets special treatment)
|
||||
]
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
timestep=timestep,
|
||||
img_shapes=img_shapes,
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
|
||||
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
@@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
@pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()