1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into feat/distill-ltx2

This commit is contained in:
Sayak Paul
2026-01-13 11:04:51 +05:30
committed by GitHub
36 changed files with 802 additions and 486 deletions

View File

@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```
## Performance
### torch.compile
Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
```python
import torch
from diffusers import QwenImagePipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)
# First call triggers compilation (~7s overhead)
# Subsequent calls run at ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```
### Batched Inference with Variable-Length Prompts
When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
```python
# CFG with different prompt lengths works correctly
image = pipe(
prompt="A cat",
negative_prompt="blurry, low quality, distorted",
true_cfg_scale=3.5,
num_inference_steps=50,
).images[0]
```
For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
## QwenImagePipeline
[[autodoc]] QwenImagePipeline

View File

@@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```
### Unified Attention
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```
> [!TIP]
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|--------------------|------------------|-------------|------------------|
| ulysses | 6670.789 | 7.50 | 33.85 |
| ring | 13076.492 | 3.82 | 56.02 |
| unified_balanced | 11068.705 | 4.52 | 33.85 |
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to number of attention-heads, a limitation that is solved by unified attention.

View File

@@ -1695,9 +1695,13 @@ def main(args):
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
device=cond_model_input.device
)
cond_model_input_ids = cond_model_input_ids.view(
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1724,6 +1728,9 @@ def main(args):
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
orig_input_shape = packed_noisy_model_input.shape
orig_input_ids_shape = model_input_ids.shape
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1742,7 +1749,8 @@ def main(args):
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = model_pred[:, : orig_input_shape[1], :]
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)

View File

@@ -1513,14 +1513,12 @@ def main(args):
height=model_input.shape[3],
width=model_input.shape[4],
)
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(

View File

@@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_lora_state_dict = {
k: state_dict.get(k)
@@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_peft_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
@@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,

View File

@@ -90,10 +90,6 @@ class ContextParallelConfig:
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

View File

@@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
return x
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
"""
Perform dimension sharding / reassembly across processes using _all_to_all_single.
This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
head dimension flexibly by accepting scatter_idx and gather_idx.
Args:
x (torch.Tensor):
Input tensor. Expected shapes:
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
scatter_idx (int) :
Dimension along which the tensor is partitioned before all-to-all.
gather_idx (int):
Dimension along which the output is reassembled after all-to-all.
group :
Distributed process group for the Ulysses group.
Returns:
torch.Tensor: Tensor with globally exchanged dimensions.
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
"""
group_world_size = torch.distributed.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
# dimension and scatters head dimension
batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
x_temp = (
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
.transpose(0, 2)
.contiguous()
)
if group_world_size > 1:
out = _all_to_all_single(x_temp, group=group)
else:
out = x_temp
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
return out
elif scatter_idx == 1 and gather_idx == 2:
# Used after ulysses sequence parallel in unified SP. gathers the head dimension
# scatters back the sequence dimension.
batch_size, seq_len, num_heads_local, head_dim = x.shape
num_heads = num_heads_local * group_world_size
seq_len_local = seq_len // group_world_size
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
x_temp = (
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
.permute(1, 3, 2, 0, 4)
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
)
if group_world_size > 1:
output = _all_to_all_single(x_temp, group)
else:
output = x_temp
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
return output
else:
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
class SeqAllToAllDim(torch.autograd.Function):
"""
all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
for more info.
"""
@staticmethod
def forward(ctx, group, input, scatter_id=2, gather_id=1):
ctx.group = group
ctx.scatter_id = scatter_id
ctx.gather_id = gather_id
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
@staticmethod
def backward(ctx, grad_outputs):
grad_input = SeqAllToAllDim.apply(
ctx.group,
grad_outputs,
ctx.gather_id, # reversed
ctx.scatter_id, # reversed
)
return (None, grad_input, None, None)
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function):
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
def _templated_unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
scatter_idx: int = 2,
gather_idx: int = 1,
):
"""
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
"""
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
ulysses_group = ulysses_mesh.get_group()
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
out = TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if return_lse:
context_layer, lse, *_ = out
else:
context_layer = out
# context_layer is of shape (B, S, H_LOCAL, D)
output = SeqAllToAllDim.apply(
ulysses_group,
context_layer,
gather_idx,
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
return (output, lse)
return output
def _templated_context_parallel_attention(
@@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention(
raise ValueError("GQA is not yet supported for templated attention.")
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if _parallel_config.context_parallel_config.ring_degree > 1:
if (
_parallel_config.context_parallel_config.ring_degree > 1
and _parallel_config.context_parallel_config.ulysses_degree > 1
):
return _templated_unified_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
@@ -1945,6 +2125,43 @@ def _native_flex_attention(
return out
def _prepare_additive_attn_mask(
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
"""
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
Args:
attn_mask: 2D tensor [batch_size, seq_len_k]
- Boolean: True means attend, False means mask out
- Additive: 0.0 means attend, -inf means mask out
target_dtype: The dtype to convert the mask to (usually query.dtype)
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
Returns:
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
"""
# Check if the mask is boolean or already additive
if attn_mask.dtype == torch.bool:
# Convert boolean to additive: True -> 0.0, False -> -inf
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
# Convert to target dtype
attn_mask = attn_mask.to(dtype=target_dtype)
else:
# Already additive mask - just ensure correct dtype
attn_mask = attn_mask.to(dtype=target_dtype)
# Optionally reshape to 4D for broadcasting in attention mechanisms
if reshape_4d:
batch_size, seq_len_k = attn_mask.shape
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
return attn_mask
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
@@ -1964,6 +2181,19 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
# Reshape 2D mask to 4D for SDPA
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
# SDPA handles both boolean and additive masks correctly
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2530,10 +2760,34 @@ def _xformers_attention(
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
# Convert 2D mask to 4D for xformers
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
# Need memory alignment - create larger tensor and slice for alignment
original_seq_len = attn_mask.size(1)
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
# Create aligned 4D tensor and slice to ensure proper memory layout
aligned_mask = torch.zeros(
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
dtype=query.dtype,
device=query.device,
)
# Convert to 4D additive mask (handles both boolean and additive inputs)
mask_additive = _prepare_additive_attn_mask(
attn_mask, target_dtype=query.dtype
) # [batch, 1, 1, seq_len_k]
# Broadcast to [batch, heads, seq_q, seq_len_k]
aligned_mask[:, :, :, :original_seq_len] = mask_additive
# Mask out the padding (already -inf from zeros -> where with default)
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
# Slice to actual size with proper alignment
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
elif attn_mask.ndim == 4:
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
if enable_gqa:
if num_heads_q % num_heads_kv != 0:

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import (
QwenImageTransformerBlock,
QwenTimestepProjEmbeddings,
RMSNorm,
compute_text_seq_len_from_mask,
)
@@ -136,7 +137,7 @@ class QwenImageControlNetModel(
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
The [`QwenImageControlNetModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ class QwenImageControlNetModel(
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
Image shapes for RoPE computation.
txt_seq_lens (`List[int]`, *optional*):
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
length.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
the first element is the controlnet block samples.
"""
# Handle deprecated txt_seq_lens parameter
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
"and `encoder_hidden_states_mask`.",
standard_warn=False,
)
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -186,32 +202,47 @@ class QwenImageControlNetModel(
temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
# Construct joint attention mask once to avoid reconstructing in every block
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
temb,
image_rotary_emb,
block_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
joint_attention_kwargs=block_attention_kwargs,
)
block_samples = block_samples + (hidden_states,)
@@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[QwenImageControlNetOutput, Tuple]:
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
standard_warn=False,
)
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1:
@@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -142,6 +142,32 @@ def apply_rotary_emb_qwen(
return x_out.type_as(x)
def compute_text_seq_len_from_mask(
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
"""
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
if encoder_hidden_states_mask is None:
return text_seq_len, None, None
if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
raise ValueError(
f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
)
if encoder_hidden_states_mask.dtype != torch.bool:
encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
return text_seq_len, per_sample_len, encoder_hidden_states_mask
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, use_additional_t_cond=False):
super().__init__()
@@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
txt_seq_lens: List[int],
device: torch.device,
txt_seq_lens: Optional[List[int]] = None,
device: torch.device = None,
max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
device: (`torch.device`):
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
device: (`torch.device`, *optional*):
The device on which to perform the RoPE computation.
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
# Handle deprecated txt_seq_lens parameter
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
"Please use `max_txt_seq_len` instead. "
"The new parameter accepts a single int or tensor value representing the maximum text sequence length.",
standard_warn=False,
)
if max_txt_seq_len is None:
# Use max of txt_seq_lens for backward compatibility
max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens
if max_txt_seq_len is None:
raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
# Validate batch inference with variable-sized images
if isinstance(video_fhw, list) and len(video_fhw) > 1:
# Check if all instances have the same size
first_fhw = video_fhw[0]
if not all(fhw == first_fhw for fhw in video_fhw):
logger.warning(
"Batch inference with variable-sized images is not currently supported in QwenEmbedRope. "
"All images in the batch should have the same dimensions (frame, height, width). "
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
"for RoPE computation, which may lead to incorrect results for other images in the batch."
)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
@@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module):
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
vid_freqs.append(video_freq)
if self.scale_rope:
@@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module):
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
max_txt_seq_len: Union[int, torch.Tensor],
device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
structures.
max_txt_seq_len (`int` or `torch.Tensor`):
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
device: (`torch.device`, *optional*):
The device on which to perform the RoPE computation.
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
# Validate batch inference with variable-sized images
# In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
if isinstance(video_fhw, list) and len(video_fhw) > 1:
# Check if this is batch inference (list of layer lists/tuples)
first_entry = video_fhw[0]
if not all(entry == first_entry for entry in video_fhw):
logger.warning(
"Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. "
"All images in the batch should have the same layer structure. "
f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} "
"for RoPE computation, which may lead to incorrect results for other images in the batch."
)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
@@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module):
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
video_freq = self._compute_condition_freqs(frame, height, width, device)
vid_freqs.append(video_freq)
if self.scale_rope:
@@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module):
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module):
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0:
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
# Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
joint_key,
@@ -762,14 +848,25 @@ class QwenImageTransformer2DModel(
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
Mask of the input conditions.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
Image shapes for RoPE computation.
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
used to compute RoPE sequence length.
guidance (`torch.Tensor`, *optional*):
Guidance tensor for conditional generation.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_block_samples (*optional*):
ControlNet block samples to add to the transformer blocks.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
@@ -778,6 +875,15 @@ class QwenImageTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
"Please use `encoder_hidden_states_mask` instead. "
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -810,6 +916,11 @@ class QwenImageTransformer2DModel(
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
@@ -819,7 +930,17 @@ class QwenImageTransformer2DModel(
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -827,10 +948,10 @@ class QwenImageTransformer2DModel(
block,
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
temb,
image_rotary_emb,
attention_kwargs,
block_attention_kwargs,
modulate_index,
)
@@ -838,10 +959,10 @@ class QwenImageTransformer2DModel(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
joint_attention_kwargs=block_attention_kwargs,
modulate_index=modulate_index,
)

View File

@@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
)
]
] * block_state.batch_size
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
self.set_block_state(state, block_state)
@@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
]
] * block_state.batch_size
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
"It should contain prompt_embeds/negative_prompt_embeds."
),
),
]
@@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
img_shapes=block_state.img_shapes,
encoder_hidden_states=block_state.prompt_embeds,
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
txt_seq_lens=block_state.txt_seq_lens,
return_dict=False,
)
@@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
@@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())

View File

@@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)
@@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
@@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,

View File

@@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)
@@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
@@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,

View File

@@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
else:
batch_size = prompt_embeds.shape[0]
# QwenImageEditPlusPipeline does not currently support batch_size > 1
if batch_size > 1:
raise ValueError(
f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
"Please process prompts one at a time."
)
device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
@@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
@@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
additional_t_cond=is_rgb,
return_dict=False,
@@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
additional_t_cond=is_rgb,
return_dict=False,
@@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w

View File

@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

View File

@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder",
)
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in CogView4.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Flux2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder_2",
)
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@nightly
@require_torch_accelerator

View File

@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 5, 32, 32, 3)
@@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in LTX2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_save_pretrained_with_text_lora(self):
pass

View File

@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 4, 4, 3)
@@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),

View File

@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 7, 16, 16, 3)
@@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

View File

@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()

View File

@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self):
pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()

View File

@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
supports_text_encoder_loras = True
unet_kwargs = None
transformer_cls = None
@@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, _, _ = self.get_dummy_components()
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
@@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)

View File

@@ -15,10 +15,10 @@
import unittest
import pytest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
@@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
}
def prepare_init_args_and_inputs_for_common(self):
@@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_infers_text_seq_len_from_mask(self):
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self):
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad():
output = model(**inputs_with_deprecated)
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 7
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
]
]
with torch.no_grad():
output = model(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
additional_t_cond=addition_t_cond,
)
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
@@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
@pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()