Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Merge branch 'main' into feat/distill-ltx2
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
    image=[image_1, image_2],
    prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
    num_inference_steps=50
).images[0]
```

## Performance

### torch.compile

Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# First call triggers compilation (~7s overhead)
# Subsequent calls run ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```
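
If you want to experiment further, `torch.compile` also accepts `fullgraph=True` (to surface graph breaks) and `mode="max-autotune"` (longer compilation, potentially faster kernels). The gains for these settings are not benchmarked here, so treat this as an optional variant:

```python
# Optional: stricter and more aggressive compilation settings (speedup not benchmarked here)
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True, mode="max-autotune")
```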

### Batched Inference with Variable-Length Prompts

When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.

```python
# CFG with different prompt lengths works correctly
image = pipe(
    prompt="A cat",
    negative_prompt="blurry, low quality, distorted",
    true_cfg_scale=3.5,
    num_inference_steps=50,
).images[0]
```
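
The same attention masking also covers plain batched generation. A minimal sketch, assuming the pipeline accepts a list of prompts (as most Diffusers pipelines do) and reusing the `pipe` object from above:

```python
# Prompts of very different token lengths in one batch; the shorter prompt is padded
# and the padding is masked out, so it does not leak into the generated images.
images = pipe(
    prompt=[
        "A cat",
        "A highly detailed photograph of a penguin hosting a game show, studio lighting, 35mm film",
    ],
    num_inference_steps=50,
).images
```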

For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).

## QwenImagePipeline

[[autodoc]] QwenImagePipeline

@@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```

### Unified Attention

[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.

This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions
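
To make the data movement concrete, here is a small, single-process sketch of the tensor layouts seen by one device in a hypothetical `ulysses_degree=2`, `ring_degree=2` grid. It only illustrates shapes; real runs perform these exchanges with `torch.distributed` all-to-all collectives:

```python
import torch

# Hypothetical toy sizes: batch B, full sequence S, heads H, head dim D,
# on a 2D grid with ulysses_degree U and ring_degree R (U * R = 4 devices).
B, S, H, D = 1, 32, 8, 64
U, R = 2, 2

# Each device starts with its own sequence shard and all attention heads.
q_local = torch.randn(B, S // (U * R), H, D)        # (1, 8, 8, 64)

# 1) Ulysses all-to-all within the U-group: gather sequence, scatter heads.
shape_after_ulysses = (B, S // R, H // U, D)         # (1, 16, 4, 64)

# 2) Ring attention within the R-group keeps this layout while K/V shards rotate.
shape_after_ring = shape_after_ulysses               # (1, 16, 4, 64)

# 3) Reverse all-to-all: gather heads back, scatter the sequence again.
shape_restored = (B, S // (U * R), H, D)             # (1, 8, 8, 64)
print(q_local.shape, shape_after_ulysses, shape_after_ring, shape_restored)
```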

[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where the Ulysses and Ring groups are orthogonal (non-overlapping). Pass a [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```
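
For context, a fuller sketch of how this call fits into a multi-GPU run is shown below. It assumes `ContextParallelConfig` is importable from the top-level `diffusers` namespace, reuses the `CKPT_ID` checkpoint from the example above, and is launched with `torchrun --nproc_per_node=4 script.py`; adapt the details to your setup.

```python
import torch
from diffusers import ContextParallelConfig, DiffusionPipeline

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=torch.bfloat16).to(device)
# 2 x 2 grid: Ulysses groups parallelize across heads, Ring groups across the sequence.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))

image = pipeline("a photo of a cat", num_inference_steps=50).images[0]
if rank == 0:
    image.save("unified_attention.png")

torch.distributed.destroy_process_group()
```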

> [!TIP]
> Unified Attention should be used when there are enough devices to arrange in a 2D grid (at least 4 devices).

We ran a benchmark comparing Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:

| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|------------------|------------------|-------------|------------------|
| ulysses          | 6670.789         | 7.50        | 33.85            |
| ring             | 13076.492        | 3.82        | 56.02            |
| unified_balanced | 11068.705        | 4.52        | 33.85            |

From the table above, it's clear that Ulysses provides the best throughput, but the number of devices it can use is limited to the number of attention heads; Unified Attention removes that limitation by adding the ring dimension.
@@ -1695,9 +1695,13 @@ def main(args):
|
||||
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
|
||||
|
||||
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
|
||||
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
|
||||
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
|
||||
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
|
||||
device=cond_model_input.device
|
||||
)
|
||||
cond_model_input_ids = cond_model_input_ids.view(
|
||||
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
|
||||
)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
@@ -1724,6 +1728,9 @@ def main(args):
|
||||
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
|
||||
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
|
||||
|
||||
orig_input_shape = packed_noisy_model_input.shape
|
||||
orig_input_ids_shape = model_input_ids.shape
|
||||
|
||||
# concatenate the model inputs with the cond inputs
|
||||
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
|
||||
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
|
||||
@@ -1742,7 +1749,8 @@ def main(args):
|
||||
img_ids=model_input_ids, # B, image_seq_len, 4
|
||||
return_dict=False,
|
||||
)[0]
|
||||
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
|
||||
model_pred = model_pred[:, : orig_input_shape[1], :]
|
||||
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
|
||||
|
||||
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
|
||||
|
||||
|
||||
@@ -1513,14 +1513,12 @@ def main(args):
|
||||
height=model_input.shape[3],
|
||||
width=model_input.shape[4],
|
||||
)
|
||||
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
timestep=timesteps / 1000,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
return_dict=False,
|
||||
)[0]
|
||||
model_pred = QwenImagePipeline._unpack_latents(
|
||||
|
||||
@@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
if not (has_lora_keys or has_norm_keys):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
transformer_lora_state_dict = {
|
||||
k: state_dict.get(k)
|
||||
@@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
transformer_peft_state_dict = {
|
||||
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
|
||||
@@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
|
||||
@@ -90,10 +90,6 @@ class ContextParallelConfig:
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
@@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
|
||||
"""
|
||||
Perform dimension sharding / reassembly across processes using _all_to_all_single.
|
||||
|
||||
This utility reshapes and redistributes tensor `x` across the given process group, along either the sequence
dimension or the head dimension, depending on `scatter_idx` and `gather_idx`.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
Input tensor. Expected shapes:
|
||||
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
|
||||
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
|
||||
scatter_idx (int) :
|
||||
Dimension along which the tensor is partitioned before all-to-all.
|
||||
gather_idx (int):
|
||||
Dimension along which the output is reassembled after all-to-all.
|
||||
group :
|
||||
Distributed process group for the Ulysses group.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor with globally exchanged dimensions.
|
||||
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
|
||||
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
|
||||
"""
|
||||
group_world_size = torch.distributed.get_world_size(group)
|
||||
|
||||
if scatter_idx == 2 and gather_idx == 1:
|
||||
# Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
# dimension and scatters the head dimension.
|
||||
batch_size, seq_len_local, num_heads, head_dim = x.shape
|
||||
seq_len = seq_len_local * group_world_size
|
||||
num_heads_local = num_heads // group_world_size
|
||||
|
||||
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
|
||||
.transpose(0, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
out = _all_to_all_single(x_temp, group=group)
|
||||
else:
|
||||
out = x_temp
|
||||
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
|
||||
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
|
||||
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
|
||||
return out
|
||||
elif scatter_idx == 1 and gather_idx == 2:
|
||||
# Used after Ulysses sequence parallel attention in unified SP. Gathers the head dimension
# and scatters the sequence dimension back.
|
||||
batch_size, seq_len, num_heads_local, head_dim = x.shape
|
||||
num_heads = num_heads_local * group_world_size
|
||||
seq_len_local = seq_len // group_world_size
|
||||
|
||||
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
|
||||
.permute(1, 3, 2, 0, 4)
|
||||
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
output = _all_to_all_single(x_temp, group)
|
||||
else:
|
||||
output = x_temp
|
||||
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
|
||||
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
|
||||
return output
|
||||
else:
|
||||
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
|
||||
|
||||
|
||||
class SeqAllToAllDim(torch.autograd.Function):
|
||||
"""
|
||||
All-to-all autograd operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function
for more info.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, group, input, scatter_id=2, gather_id=1):
|
||||
ctx.group = group
|
||||
ctx.scatter_id = scatter_id
|
||||
ctx.gather_id = gather_id
|
||||
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_outputs):
|
||||
grad_input = SeqAllToAllDim.apply(
|
||||
ctx.group,
|
||||
grad_outputs,
|
||||
ctx.gather_id, # reversed
|
||||
ctx.scatter_id, # reversed
|
||||
)
|
||||
return (None, grad_input, None, None)
|
||||
|
||||
|
||||
class TemplatedRingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
out = out.to(torch.float32)
|
||||
lse = lse.to(torch.float32)
|
||||
|
||||
lse = lse.unsqueeze(-1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1)
|
||||
if prev_out is not None:
|
||||
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
|
||||
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
|
||||
@@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
|
||||
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
@@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def _templated_unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
scale: Optional[float],
|
||||
enable_gqa: bool,
|
||||
return_lse: bool,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
scatter_idx: int = 2,
|
||||
gather_idx: int = 1,
|
||||
):
|
||||
"""
|
||||
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
|
||||
"""
|
||||
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
|
||||
ulysses_group = ulysses_mesh.get_group()
|
||||
|
||||
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
|
||||
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
|
||||
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
|
||||
out = TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
context_layer, lse, *_ = out
|
||||
else:
|
||||
context_layer = out
|
||||
# context_layer is of shape (B, S, H_LOCAL, D)
|
||||
output = SeqAllToAllDim.apply(
|
||||
ulysses_group,
|
||||
context_layer,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
)
|
||||
if return_lse:
|
||||
# lse is of shape (B, S, H_LOCAL, 1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
|
||||
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
|
||||
lse = lse.squeeze(-1)
|
||||
return (output, lse)
|
||||
return output
|
||||
|
||||
|
||||
def _templated_context_parallel_attention(
|
||||
@@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention(
|
||||
raise ValueError("GQA is not yet supported for templated attention.")
|
||||
|
||||
# TODO: add support for unified attention with ring/ulysses degree both being > 1
|
||||
if _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
if (
|
||||
_parallel_config.context_parallel_config.ring_degree > 1
|
||||
and _parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
):
|
||||
return _templated_unified_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
elif _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
return TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
@@ -1945,6 +2125,43 @@ def _native_flex_attention(
|
||||
return out
|
||||
|
||||
|
||||
def _prepare_additive_attn_mask(
|
||||
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
|
||||
|
||||
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
|
||||
|
||||
Args:
|
||||
attn_mask: 2D tensor [batch_size, seq_len_k]
|
||||
- Boolean: True means attend, False means mask out
|
||||
- Additive: 0.0 means attend, -inf means mask out
|
||||
target_dtype: The dtype to convert the mask to (usually query.dtype)
|
||||
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
|
||||
|
||||
Returns:
|
||||
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
|
||||
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
|
||||
"""
|
||||
# Check if the mask is boolean or already additive
|
||||
if attn_mask.dtype == torch.bool:
|
||||
# Convert boolean to additive: True -> 0.0, False -> -inf
|
||||
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
|
||||
# Convert to target dtype
|
||||
attn_mask = attn_mask.to(dtype=target_dtype)
|
||||
else:
|
||||
# Already additive mask - just ensure correct dtype
|
||||
attn_mask = attn_mask.to(dtype=target_dtype)
|
||||
|
||||
# Optionally reshape to 4D for broadcasting in attention mechanisms
|
||||
if reshape_4d:
|
||||
batch_size, seq_len_k = attn_mask.shape
|
||||
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
|
||||
|
||||
return attn_mask
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
@@ -1964,6 +2181,19 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
# Reshape 2D mask to 4D for SDPA
|
||||
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
|
||||
if (
|
||||
attn_mask is not None
|
||||
and attn_mask.ndim == 2
|
||||
and attn_mask.shape[0] == query.shape[0]
|
||||
and attn_mask.shape[1] == key.shape[1]
|
||||
):
|
||||
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
|
||||
# SDPA handles both boolean and additive masks correctly
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
@@ -2530,10 +2760,34 @@ def _xformers_attention(
|
||||
attn_mask = xops.LowerTriangularMask()
|
||||
elif attn_mask is not None:
|
||||
if attn_mask.ndim == 2:
|
||||
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
|
||||
# Convert 2D mask to 4D for xformers
|
||||
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
|
||||
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
|
||||
# Need memory alignment - create larger tensor and slice for alignment
|
||||
original_seq_len = attn_mask.size(1)
|
||||
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
|
||||
|
||||
# Create aligned 4D tensor and slice to ensure proper memory layout
|
||||
aligned_mask = torch.zeros(
|
||||
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
# Convert to 4D additive mask (handles both boolean and additive inputs)
|
||||
mask_additive = _prepare_additive_attn_mask(
|
||||
attn_mask, target_dtype=query.dtype
|
||||
) # [batch, 1, 1, seq_len_k]
|
||||
# Broadcast to [batch, heads, seq_q, seq_len_k]
|
||||
aligned_mask[:, :, :, :original_seq_len] = mask_additive
|
||||
# Mask out the padding columns beyond the original sequence length
|
||||
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
|
||||
|
||||
# Slice to actual size with proper alignment
|
||||
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
|
||||
elif attn_mask.ndim != 4:
|
||||
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
elif attn_mask.ndim == 4:
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
|
||||
if enable_gqa:
|
||||
if num_heads_q % num_heads_kv != 0:
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..controlnets.controlnet import zero_module
|
||||
@@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import (
|
||||
QwenImageTransformerBlock,
|
||||
QwenTimestepProjEmbeddings,
|
||||
RMSNorm,
|
||||
compute_text_seq_len_from_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -136,7 +137,7 @@ class QwenImageControlNetModel(
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
The [`QwenImageControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
@@ -147,24 +148,39 @@ class QwenImageControlNetModel(
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
|
||||
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
|
||||
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
|
||||
Image shapes for RoPE computation.
|
||||
txt_seq_lens (`List[int]`, *optional*):
|
||||
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
|
||||
length.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
|
||||
the first element is the controlnet block samples.
|
||||
"""
|
||||
# Handle deprecated txt_seq_lens parameter
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
|
||||
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
|
||||
"and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
@@ -186,32 +202,47 @@ class QwenImageControlNetModel(
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
|
||||
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Construct joint attention mask once to avoid reconstructing in every block
|
||||
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
# Build joint mask: [text_mask, all_ones_for_image]
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
block_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=block_attention_kwargs,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
@@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[QwenImageControlNetOutput, Tuple]:
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
|
||||
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
|
||||
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
@@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
timestep=timestep,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
@@ -142,6 +142,32 @@ def apply_rotary_emb_qwen(
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
def compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
|
||||
) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Compute the text sequence length without assuming contiguous masks. Returns the sequence length to use for RoPE, the per-sample valid lengths, and a normalized boolean mask.
|
||||
"""
|
||||
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
|
||||
if encoder_hidden_states_mask is None:
|
||||
return text_seq_len, None, None
|
||||
|
||||
if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
|
||||
raise ValueError(
|
||||
f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
|
||||
f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
|
||||
)
|
||||
|
||||
if encoder_hidden_states_mask.dtype != torch.bool:
|
||||
encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
|
||||
|
||||
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
|
||||
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
|
||||
has_active = encoder_hidden_states_mask.any(dim=1)
|
||||
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
|
||||
return text_seq_len, per_sample_len, encoder_hidden_states_mask
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
@@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
|
||||
txt_seq_lens: List[int],
|
||||
device: torch.device,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
device: torch.device = None,
|
||||
max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video.
|
||||
txt_seq_lens (`List[int]`):
|
||||
A list of integers of length batch_size representing the length of each text prompt.
|
||||
device: (`torch.device`):
|
||||
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
|
||||
Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
|
||||
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
|
||||
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
# Handle deprecated txt_seq_lens parameter
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
|
||||
"Please use `max_txt_seq_len` instead. "
|
||||
"The new parameter accepts a single int or tensor value representing the maximum text sequence length.",
|
||||
standard_warn=False,
|
||||
)
|
||||
if max_txt_seq_len is None:
|
||||
# Use max of txt_seq_lens for backward compatibility
|
||||
max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens
|
||||
|
||||
if max_txt_seq_len is None:
|
||||
raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
|
||||
|
||||
# Validate batch inference with variable-sized images
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
# Check if all instances have the same size
|
||||
first_fhw = video_fhw[0]
|
||||
if not all(fhw == first_fhw for fhw in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in QwenEmbedRope. "
|
||||
"All images in the batch should have the same dimensions (frame, height, width). "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
@@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module):
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = video_freq.to(device)
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
@@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module):
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
@@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
|
||||
max_txt_seq_len: Union[int, torch.Tensor],
|
||||
device: torch.device = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
Args:
|
||||
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
|
||||
structures.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`):
|
||||
The maximum text sequence length for RoPE computation. This should match the encoder hidden states
|
||||
sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
# Validate batch inference with variable-sized images
|
||||
# In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
# Check if this is batch inference (list of layer lists/tuples)
|
||||
first_entry = video_fhw[0]
|
||||
if not all(entry == first_entry for entry in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. "
|
||||
"All images in the batch should have the same layer structure. "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
@@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
if idx != layer_num:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
else:
|
||||
### For the condition image, we set the layer index to -1
|
||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||
video_freq = video_freq.to(device)
|
||||
video_freq = self._compute_condition_freqs(frame, height, width, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
@@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
# Create device-specific copy for text freqs without modifying self.pos_freqs
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
@@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width):
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
@@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
# Compute joint attention
|
||||
joint_hidden_states = dispatch_attention_fn(
|
||||
joint_query,
|
||||
joint_key,
|
||||
@@ -762,14 +848,25 @@ class QwenImageTransformer2DModel(
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
||||
Mask of the input conditions.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
|
||||
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
|
||||
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
|
||||
Image shapes for RoPE computation.
|
||||
txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
|
||||
Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
|
||||
used to compute RoPE sequence length.
|
||||
guidance (`torch.Tensor`, *optional*):
|
||||
Guidance tensor for conditional generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_block_samples (*optional*):
|
||||
ControlNet block samples to add to the transformer blocks.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
@@ -778,6 +875,15 @@ class QwenImageTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
|
||||
"Please use `encoder_hidden_states_mask` instead. "
|
||||
"The mask-based approach is more flexible and supports variable-length sequences.",
|
||||
standard_warn=False,
|
||||
)
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
@@ -810,6 +916,11 @@ class QwenImageTransformer2DModel(
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
|
||||
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
@@ -819,7 +930,17 @@ class QwenImageTransformer2DModel(
|
||||
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
# Construct joint attention mask once to avoid reconstructing in every block
|
||||
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
|
||||
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
# Build joint mask: [text_mask, all_ones_for_image]
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -827,10 +948,10 @@ class QwenImageTransformer2DModel(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_kwargs,
|
||||
block_attention_kwargs,
|
||||
modulate_index,
|
||||
)
|
||||
|
||||
@@ -838,10 +959,10 @@ class QwenImageTransformer2DModel(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=attention_kwargs,
|
||||
joint_attention_kwargs=block_attention_kwargs,
|
||||
modulate_index=modulate_index,
|
||||
)
|
||||
|
||||
|
||||
@@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                type_hint=List[List[Tuple[int, int, int]]],
                description="The shapes of the images latents, used for RoPE calculation",
            ),
            OutputParam(
                name="txt_seq_lens",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            OutputParam(
                name="negative_txt_seq_lens",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:

@@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                )
            ]
        ] * block_state.batch_size
        block_state.txt_seq_lens = (
            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
        )
        block_state.negative_txt_seq_lens = (
            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
            if block_state.negative_prompt_embeds_mask is not None
            else None
        )

        self.set_block_state(state, block_state)

@@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
                type_hint=List[List[Tuple[int, int, int]]],
                description="The shapes of the images latents, used for RoPE calculation",
            ),
            OutputParam(
                name="txt_seq_lens",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            OutputParam(
                name="negative_txt_seq_lens",
                kwargs_type="denoiser_input_fields",
                type_hint=List[int],
                description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:

@@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
            ]
        ] * block_state.batch_size

        block_state.txt_seq_lens = (
            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
        )
        block_state.negative_txt_seq_lens = (
            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
            if block_state.negative_prompt_embeds_mask is not None
            else None
        )

        self.set_block_state(state, block_state)

        return components, state
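The `txt_seq_lens`/`negative_txt_seq_lens` outputs removed above carried no information beyond what the prompt embedding masks already encode, which is why the modular blocks can drop them without a replacement; the short snippet below illustrates the equivalence (tensor values are made up).

```python
import torch

# Illustrative: the per-sample text lengths previously exported as `txt_seq_lens`
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [3, 5] — now derived inside the transformer from the mask
```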
@@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
                kwargs_type="denoiser_input_fields",
                description=(
                    "All conditional model inputs for the denoiser. "
-                    "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+                    "It should contain prompt_embeds/negative_prompt_embeds."
                ),
            ),
        ]

@@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
            img_shapes=block_state.img_shapes,
            encoder_hidden_states=block_state.prompt_embeds,
            encoder_hidden_states_mask=block_state.prompt_embeds_mask,
            txt_seq_lens=block_state.txt_seq_lens,
            return_dict=False,
        )

@@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
                getattr(block_state, "prompt_embeds_mask", None),
                getattr(block_state, "negative_prompt_embeds_mask", None),
            ),
            "txt_seq_lens": (
                getattr(block_state, "txt_seq_lens", None),
                getattr(block_state, "negative_txt_seq_lens", None),
            ),
        }

        transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())

@@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
                getattr(block_state, "prompt_embeds_mask", None),
                getattr(block_state, "negative_prompt_embeds_mask", None),
            ),
            "txt_seq_lens": (
                getattr(block_state, "txt_seq_lens", None),
                getattr(block_state, "negative_txt_seq_lens", None),
            ),
        }

        transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
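Note that both loop denoisers keep the `inspect.signature(...)` filter, which is what makes removing `txt_seq_lens` from the conditional-input dictionary safe: only keyword arguments that the loaded transformer's `forward` actually declares get forwarded. A self-contained sketch of that filtering pattern (the dummy `forward` and kwargs are invented for illustration):

```python
import inspect


def forward(hidden_states, encoder_hidden_states, encoder_hidden_states_mask=None):
    return hidden_states


candidate_kwargs = {
    "hidden_states": "x",
    "encoder_hidden_states": "ctx",
    "encoder_hidden_states_mask": None,
    "txt_seq_lens": [7],  # no longer a forward() parameter
}

accepted = set(inspect.signature(forward).parameters.keys())
filtered = {k: v for k, v in candidate_kwargs.items() if k in accepted}
print(sorted(filtered))  # ['encoder_hidden_states', 'encoder_hidden_states_mask', 'hidden_states']
```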
@@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:

@@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
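For code that calls `QwenImageTransformer2DModel` directly rather than through a pipeline, the practical effect of these hunks is that `txt_seq_lens` can simply be omitted as long as `encoder_hidden_states_mask` is supplied and the mask-driven path introduced by this diff is available. The sketch below uses a tiny, randomly initialized configuration similar to the one in the model tests further down; it is illustrative only, not a real checkpoint.

```python
import torch
from diffusers import QwenImageTransformer2DModel

# Tiny dummy configuration purely for illustration (mirrors the test config, not a real checkpoint).
model = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=4,
    num_layers=2,
    attention_head_dim=16,
    num_attention_heads=3,
    joint_attention_dim=16,
    axes_dims_rope=(8, 4, 4),
)

hidden_states = torch.randn(1, 16, 16)          # (batch, image tokens, in_channels)
encoder_hidden_states = torch.randn(1, 7, 16)   # (batch, text tokens, joint_attention_dim)
encoder_hidden_states_mask = torch.ones(1, 7, dtype=torch.bool)
encoder_hidden_states_mask[:, 5:] = False       # last two text tokens are padding

with torch.no_grad():
    sample = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,  # txt_seq_lens no longer needed
        timestep=torch.tensor([1.0]),
        img_shapes=[[(1, 4, 4)]],
        return_dict=False,
    )[0]
print(sample.shape)
```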
@@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    return_dict=False,
                )

@@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,

@@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
                        controlnet_block_samples=controlnet_block_samples,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,

@@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    return_dict=False,
                )

@@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,

@@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
                        controlnet_block_samples=controlnet_block_samples,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
@@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:

@@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]

@@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

@@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
@@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        else:
            batch_size = prompt_embeds.shape[0]

        # QwenImageEditPlusPipeline does not currently support batch_size > 1
        if batch_size > 1:
            raise ValueError(
                f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
                "Please process prompts one at a time."
            )

        device = self._execution_device
        # 3. Preprocess image
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):

@@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:

@@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
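Because `QwenImageEditPlusPipeline` now raises a `ValueError` when `batch_size > 1`, callers with several edit prompts should iterate instead of batching. A rough sketch of that workaround follows; the checkpoint id and image URL are placeholders, not values taken from this diff.

```python
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

# Placeholders: substitute the checkpoint and input image you actually use.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")
source = load_image("https://example.com/input.png")

prompts = ["make the sky pink", "turn the photo into a watercolor painting"]
edited = []
for prompt in prompts:  # batch_size must stay 1, so run prompts one at a time
    edited.append(pipe(image=[source], prompt=prompt, num_inference_steps=50).images[0])
```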
@@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

@@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]

@@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

@@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

@@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]
@@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )
        is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
        # 6. Denoising loop
        self.scheduler.set_begin_index(0)

@@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    additional_t_cond=is_rgb,
                    return_dict=False,

@@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=negative_txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                        additional_t_cond=is_rgb,
                        return_dict=False,

@@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
        latents = latents[:, :, 1:]  # remove the first frame as it is the orgin input

-        latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
+        latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)

        image = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
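The `view` → `reshape` change in the last hunk is a correctness fix: `permute` returns a non-contiguous tensor, and `Tensor.view` raises a `RuntimeError` on layouts it cannot express without copying, while `reshape` copies when needed. The standalone snippet below reproduces the failure mode with arbitrary small shapes.

```python
import torch

b, f, c, h, w = 1, 3, 4, 2, 2
latents = torch.randn(b, c, f, h, w)

permuted = latents.permute(0, 2, 1, 3, 4)        # (b, f, c, h, w), non-contiguous
print(permuted.is_contiguous())                  # False

try:
    permuted.view(-1, c, 1, h, w)
except RuntimeError as err:
    print("view failed:", err)

out = permuted.reshape(-1, c, 1, h, w)           # reshape copies if necessary
print(out.shape)                                 # torch.Size([3, 4, 1, 2, 2])
```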
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_target_modules = ["q", "k", "v", "o"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)

@@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in AuraFlow.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)

@@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass

@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder",
    )

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)

@@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in CogView4.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
    denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)

@@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Flux2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder_2",
    )

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)

@@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass


@nightly
@require_torch_accelerator

@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 5, 32, 32, 3)

@@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTX2.")
    def test_simple_inference_save_pretrained_with_text_lora(self):
        pass

@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)

@@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in LTXVideo.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 4, 4, 3)

@@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),

@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 7, 16, 16, 3)

@@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass

@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    )
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 8, 8, 3)

@@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Qwen Image.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)

@@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in SANA.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_inference_denoiser(self):
        return super().test_layerwise_casting_inference_denoiser()

@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)

@@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Wan.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)

@@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass

    def test_layerwise_casting_inference_denoiser(self):
        super().test_layerwise_casting_inference_denoiser()

@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Qwen3Model, None  # Will be created inline
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 32, 32, 3)

@@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in ZImage.")
    def test_modify_padding_mode(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_partial_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_fused(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
    def test_simple_inference_with_text_lora_save_load(self):
        pass
@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
    supports_text_encoder_loras = True

    unet_kwargs = None
    transformer_cls = None

@@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)

@@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)

@@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)

@@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)

@@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests:
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, _, _ = self.get_dummy_components()
        # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
        text_lora_config = LoraConfig(

@@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")

        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
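Taken together with the per-model `supports_text_encoder_loras = False` flags above, this replaces stacks of `@unittest.skip` decorators with a single capability flag checked inside the shared tests. A condensed, self-contained sketch of the pattern (the class names here are invented):

```python
import unittest

import pytest


class LoaderMixinTestsSketch:
    """Condensed illustration of the capability-flag pattern used above."""

    supports_text_encoder_loras = True  # shared default; subclasses opt out

    def test_simple_inference_with_text_lora(self):
        if not self.supports_text_encoder_loras:
            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
        # ... the actual text-encoder LoRA assertions would run here ...


class SomeModelLoRATestsSketch(unittest.TestCase, LoaderMixinTestsSketch):
    # Opting out once replaces a stack of per-test @unittest.skip decorators.
    supports_text_encoder_loras = False
```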
@@ -15,10 +15,10 @@

import unittest

import pytest
import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin

@@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
        }

    def prepare_init_args_and_inputs_for_common(self):

@@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
        expected_set = {"QwenImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_infers_text_seq_len_from_mask(self):
        """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask[:, 2:] = 0  # Only first 2 tokens are valid

        rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )

        # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
        self.assertIsInstance(rope_text_seq_len, int)

        # Verify per_sample_len is computed correctly (max valid position + 1 = 2)
        self.assertIsInstance(per_sample_len, torch.Tensor)
        self.assertEqual(int(per_sample_len.max().item()), 2)

        # Verify mask is normalized to bool dtype
        self.assertTrue(normalized_mask.dtype == torch.bool)
        self.assertEqual(normalized_mask.sum().item(), 2)  # Only 2 True values

        # Verify rope_text_seq_len is at least the sequence length
        self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])

        # Test 2: Verify model runs successfully with inferred values
        inputs["encoder_hidden_states_mask"] = normalized_mask
        with torch.no_grad():
            output = model(**inputs)
        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

        # Test 3: Different mask pattern (padding at beginning)
        encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask2[:, :3] = 0  # First 3 tokens are padding
        encoder_hidden_states_mask2[:, 3:] = 1  # Last 4 tokens are valid

        rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask2
        )

        # Max valid position is 6 (last token), so per_sample_len should be 7
        self.assertEqual(int(per_sample_len2.max().item()), 7)
        self.assertEqual(normalized_mask2.sum().item(), 4)  # 4 True values

        # Test 4: No mask provided (None case)
        rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], None
        )
        self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
        self.assertIsInstance(rope_text_seq_len_none, int)
        self.assertIsNone(per_sample_len_none)
        self.assertIsNone(normalized_mask_none)

    def test_non_contiguous_attention_mask(self):
        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        # Pattern: [True, False, True, False, True, False, False]
        encoder_hidden_states_mask[:, 1] = 0
        encoder_hidden_states_mask[:, 3] = 0
        encoder_hidden_states_mask[:, 5:] = 0

        inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )
        self.assertEqual(int(per_sample_len.max().item()), 5)
        self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
        self.assertIsInstance(inferred_rope_len, int)
        self.assertTrue(normalized_mask.dtype == torch.bool)

        inputs["encoder_hidden_states_mask"] = normalized_mask

        with torch.no_grad():
            output = model(**inputs)

        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_txt_seq_lens_deprecation(self):
        """Test that passing txt_seq_lens raises a deprecation warning."""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Prepare inputs with txt_seq_lens (deprecated parameter)
        txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]

        # Remove encoder_hidden_states_mask to use the deprecated path
        inputs_with_deprecated = inputs.copy()
        inputs_with_deprecated.pop("encoder_hidden_states_mask")
        inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

        # Test that deprecation warning is raised
        with self.assertWarns(FutureWarning) as warning_context:
            with torch.no_grad():
                output = model(**inputs_with_deprecated)

        # Verify the warning message mentions the deprecation
        warning_message = str(warning_context.warning)
        self.assertIn("txt_seq_lens", warning_message)
        self.assertIn("deprecated", warning_message)
        self.assertIn("encoder_hidden_states_mask", warning_message)

        # Verify the model still works correctly despite the deprecation
        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_layered_model_with_mask(self):
        """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
        # Create layered model config
        init_dict = {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
            "num_attention_heads": 3,
            "joint_attention_dim": 16,
            "axes_dims_rope": (8, 4, 4),  # Must match attention_head_dim (8+4+4=16)
            "use_layer3d_rope": True,  # Enable layered RoPE
            "use_additional_t_cond": True,  # Enable additional time conditioning
        }

        model = self.model_class(**init_dict).to(torch_device)

        # Verify the model uses QwenEmbedLayer3DRope
        from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

        self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)

        # Test single generation with layered structure
        batch_size = 1
        text_seq_len = 7
        img_h, img_w = 4, 4
        layers = 4

        # For layered model: (layers + 1) because we have N layers + 1 combined image
        hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
        encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)

        # Create mask with some padding
        encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
        encoder_hidden_states_mask[0, 5:] = 0  # Only 5 valid tokens

        timestep = torch.tensor([1.0]).to(torch_device)

        # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)

        # Layer structure: 4 layers + 1 condition image
        img_shapes = [
            [
                (1, img_h, img_w),  # layer 0
                (1, img_h, img_w),  # layer 1
                (1, img_h, img_w),  # layer 2
                (1, img_h, img_w),  # layer 3
                (1, img_h, img_w),  # condition image (last one gets special treatment)
            ]
        ]

        with torch.no_grad():
            output = model(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                timestep=timestep,
                img_shapes=img_shapes,
                additional_t_cond=addition_t_cond,
            )

        self.assertEqual(output.sample.shape[1], hidden_states.shape[1])


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = QwenImageTransformer2DModel

@@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
    def prepare_dummy_input(self, height, width):
        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)

    @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()
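The deprecation test pins down the expected behaviour: passing `txt_seq_lens` must still run but must emit a `FutureWarning` whose message names both the old argument and its replacement. The snippet below is a generic sketch of such a shim in plain Python; the actual model code may instead go through diffusers' internal deprecation utilities.

```python
import warnings


def forward_sketch(encoder_hidden_states_mask=None, txt_seq_lens=None, **kwargs):
    # Accept the legacy argument, warn, and fall back to the mask-driven path.
    if txt_seq_lens is not None:
        warnings.warn(
            "`txt_seq_lens` is deprecated and will be removed in a future release; "
            "pass `encoder_hidden_states_mask` instead.",
            FutureWarning,
        )
    # ... continue with encoder_hidden_states_mask (or infer a full-length mask) ...
    return encoder_hidden_states_mask


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    forward_sketch(txt_seq_lens=[7])
print(caught[0].category.__name__)  # FutureWarning
```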