diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index 1dcf3f944f..ee3dd3b28e 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### torch.compile + +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). + ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index b7cd0e20f4..b7255f74af 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained( CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, ).to(device) ``` +### Unified Attention + +[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout. + +This hybrid approach leverages the strengths of both methods: +- **Ulysses Attention** efficiently parallelizes across attention heads +- **Ring Attention** handles very long sequences with minimal memory overhead +- Together, they enable 2D parallelization across both heads and sequence dimensions + +[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping). +Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`]. 
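+
+To make the grid concrete, the toy sketch below (plain Python; the exact rank-to-group assignment is only illustrative) shows how 4 ranks split into orthogonal Ulysses and Ring groups when `ulysses_degree=2` and `ring_degree=2`:
+
+```py
+ulysses_degree, ring_degree = 2, 2
+world_size = ulysses_degree * ring_degree  # 4 devices in total
+for rank in range(world_size):
+    # Every rank belongs to exactly one Ulysses group and one Ring group; the two groupings
+    # partition the same devices along orthogonal axes of the 2D grid.
+    print(f"rank {rank}: ulysses group {rank // ulysses_degree}, ring group {rank % ulysses_degree}")
+```
+
+With the grid sized this way, enable it on the transformer: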
+ +```py +pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)) +``` + +> [!TIP] +> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices). + +We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows: + +| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | +|--------------------|------------------|-------------|------------------| +| ulysses | 6670.789 | 7.50 | 33.85 | +| ring | 13076.492 | 3.82 | 56.02 | +| unified_balanced | 11068.705 | 4.52 | 33.85 | + +From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to number of attention-heads, a limitation that is solved by unified attention. diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 419821e8a8..5af0906642 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1695,9 +1695,13 @@ def main(args): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1724,6 +1728,9 @@ def main(args): packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) @@ -1742,7 +1749,8 @@ def main(args): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 53b01bf0cf..ea9b137b0a 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,14 +1513,12 @@ def main(args): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), 
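+                    # `txt_seq_lens` is intentionally dropped: the transformer now infers the text
+                    # sequence length from `encoder_hidden_states_mask` (see the deprecation in `transformer_qwenimage.py`).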
return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5fc650a80d..24d1fd7b93 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") transformer_peft_state_dict = { k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") @@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin): is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 2a4eb520c7..1c7703a13c 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -90,10 +90,6 @@ class ContextParallelConfig: ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if self.ring_degree > 1 and self.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index bdb9c4861a..f4ec497038 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x +def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + """ + Perform dimension sharding / reassembly across processes using _all_to_all_single. + + This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or + head dimension flexibly by accepting scatter_idx and gather_idx. + + Args: + x (torch.Tensor): + Input tensor. Expected shapes: + - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim) + - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim) + scatter_idx (int) : + Dimension along which the tensor is partitioned before all-to-all. + gather_idx (int): + Dimension along which the output is reassembled after all-to-all. + group : + Distributed process group for the Ulysses group. + + Returns: + torch.Tensor: Tensor with globally exchanged dimensions. + - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim) + - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim) + """ + group_world_size = torch.distributed.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + # Used before Ulysses sequence parallel (SP) attention. 
Gathers the sequence
+        # dimension and scatters the head dimension
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size
+
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )
+
+        if group_world_size > 1:
+            out = _all_to_all_single(x_temp, group=group)
+        else:
+            out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Used after Ulysses sequence parallel attention in unified SP. Gathers the head dimension
+        # and scatters back the sequence dimension.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size
+
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
+            output = _all_to_all_single(x_temp, group)
+        else:
+            output = x_temp
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
+
+
+class SeqAllToAllDim(torch.autograd.Function):
+    """
+    All-to-all autograd wrapper for unified sequence parallelism. See `_all_to_all_dim_exchange` for
+    details on the layout exchange.
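+
+    With a Ulysses group of size `w`, `SeqAllToAllDim.apply(group, q, 2, 1)` turns `q` of shape
+    (B, S_local, H, D) into (B, S_local * w, H // w, D); applying it again with scatter_id=1 and
+    gather_id=2 restores the original layout.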
+ """ + + @staticmethod + def forward(ctx, group, input, scatter_id=2, gather_id=1): + ctx.group = group + ctx.scatter_id = scatter_id + ctx.gather_id = gather_id + return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) + + @staticmethod + def backward(ctx, grad_outputs): + grad_input = SeqAllToAllDim.apply( + ctx.group, + grad_outputs, + ctx.gather_id, # reversed + ctx.scatter_id, # reversed + ) + return (None, grad_input, None, None) + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function): out = out.to(torch.float32) lse = lse.to(torch.float32) - lse = lse.unsqueeze(-1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) @@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function): grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): @@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function): x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None + + +def _templated_unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + scatter_idx: int = 2, + gather_idx: int = 1, +): + """ + Unified Sequence Parallelism attention combining Ulysses and ring attention. 
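+    The flow is: an all-to-all that scatters attention heads and gathers the full sequence, ring attention
+    over the full sequence with the locally held heads, then a reverse all-to-all to restore the original
+    head/sequence layout.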
See: https://arxiv.org/abs/2405.07719 + """ + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + # context_layer is of shape (B, S, H_LOCAL, D) + output = SeqAllToAllDim.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) + if return_lse: + # lse is of shape (B, S, H_LOCAL, 1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) + lse = lse.squeeze(-1) + return (output, lse) + return output def _templated_context_parallel_attention( @@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention( raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1: + if ( + _parallel_config.context_parallel_config.ring_degree > 1 + and _parallel_config.context_parallel_config.ulysses_degree > 1 + ): + return _templated_unified_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( query, key, @@ -1945,6 +2125,43 @@ def _native_flex_attention( return out +def _prepare_additive_attn_mask( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. + + This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. + + Args: + attn_mask: 2D tensor [batch_size, seq_len_k] + - Boolean: True means attend, False means mask out + - Additive: 0.0 means attend, -inf means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. 
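+
+    Example:
+        >>> mask = torch.tensor([[True, True, False]])
+        >>> _prepare_additive_attn_mask(mask, torch.float32).shape
+        torch.Size([1, 1, 1, 3])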
+ """ + # Check if the mask is boolean or already additive + if attn_mask.dtype == torch.bool: + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + else: + # Already additive mask - just ensure correct dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -1964,6 +2181,19 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Reshape 2D mask to 4D for SDPA + # SDPA accepts both boolean masks (torch.bool) and additive masks (float) + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] + # SDPA handles both boolean and additive masks correctly + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2530,10 +2760,34 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D mask to 4D for xformers + # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) + # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Convert to 4D additive mask (handles both boolean and additive inputs) + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 8697127178..fa374285ee 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ import torch.nn as nn from 
...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import ( QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -136,7 +137,7 @@ class QwenImageControlNetModel( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -147,24 +148,39 @@ class QwenImageControlNetModel( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.39.0. 
The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -186,32 +202,47 @@ class QwenImageControlNetModel( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Construct joint attention mask once to avoid reconstructing in every block + block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask + block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, + block_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.39.0. 
The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 1229bab169..a8c98201d9 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -142,6 +142,32 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) +def compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] +) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. + """ + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, use_additional_t_cond=False): super().__init__() @@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: List[int], - device: torch.device, + txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`): - A list of integers of length batch_size representing the length of each text prompt. 
- device: (`torch.device`): + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. + device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `max_txt_seq_len` instead. " + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", + standard_warn=False, + ) + if max_txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module): else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
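+        # Text tokens take RoPE positions starting at `max_vid_index`, i.e. right after the largest
+        # image coordinate, so text positions never collide with image positions.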
vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward(self, video_fhw, txt_seq_lens, device): + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: - txt_length: [bs] a list of 1 integers representing the length of the text + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." 
+ ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width) - video_freq = video_freq.to(device) + video_freq = self._compute_condition_freqs(frame, height, width, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module): max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0): + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module): return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width): + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0: joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, joint_key, @@ -762,14 +848,25 @@ class QwenImageTransformer2DModel( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. 
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -778,6 +875,15 @@ class QwenImageTransformer2DModel( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `encoder_hidden_states_mask` instead. 
" + "The mask-based approach is more flexible and supports variable-length sequences.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -810,6 +916,11 @@ class QwenImageTransformer2DModel( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -819,7 +930,17 @@ class QwenImageTransformer2DModel( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + + # Construct joint attention mask once to avoid reconstructing in every block + # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -827,10 +948,10 @@ class QwenImageTransformer2DModel( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, - attention_kwargs, + block_attention_kwargs, modulate_index, ) @@ -838,10 +959,10 @@ class QwenImageTransformer2DModel( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, modulate_index=modulate_index, ) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index e14164229c..d9c8cbb01d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, 
components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks): ) ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) self.set_block_state(state, block_state) @@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index eb1e5a341c..d6bcb4a94f 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + "It should contain prompt_embeds/negative_prompt_embeds." 
), ), ] @@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks): getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) @@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks): getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b9..bc3ce84e10 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93..ce6fc974a5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab5..77d78a5ca7 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git 
a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8..dd723460a5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa..cf467203a9 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf16..257e2d846c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): else: batch_size = prompt_embeds.shape[0] + # QwenImageEditPlusPipeline does not currently support batch_size > 1 + if batch_size > 1: + raise ValueError( + f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " + "Please process prompts one at a time." + ) + device = self._execution_device # 3. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): @@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016..e0b41b8b87 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2..83f02539b1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 7bb12c26ba..53d2c169ee 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. 
Denoising loop self.scheduler.set_begin_index(0) @@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as latents = latents[:, :, 1:] # remove the first frame as it is the orgin input - latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56..78ef4ce151 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2..7bd54b77ca 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def 
test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb63..e8ee6e7a7d 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 4ae189aceb..d970b7d784 100644 --- a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Flux2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a..e59bc5662f 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder_2", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def 
test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py index 886ae70b7d..0a4b14454f 100644 --- a/tests/lora/test_lora_layers_ltx2.py +++ b/tests/lora/test_lora_layers_ltx2.py @@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 5, 32, 32, 3) @@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTX2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_save_pretrained_with_text_lora(self): - pass diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e51..095e5b577c 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py 
b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33..da032229a7 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 4, 4, 3) @@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db..ee82541129 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 7, 16, 16, 3) @@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20..73fd026a67 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ) denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def 
test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index a860b7b44f..97bf5cbba9 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference_denoiser(self): return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b41..5ae16ab4b9 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index 
ab1f57bfc9..c8acaea9be 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index 35d1389d96..8432ea56a6 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @unittest.skip("Not supported in ZImage.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5fae6cac0a..efa49b9f48 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests: tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + supports_text_encoder_loras = True unet_kwargs = None transformer_cls = None @@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -457,6 +461,9 @@ 
class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests: Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA. """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests: with different ranks and some adapters removed and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, _, _ = self.get_dummy_components() # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests: """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503..384954dfba 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" + init_dict, inputs = 
self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) + + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) + + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask + with torch.no_grad(): + output = model(**inputs) + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. 
+ encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(int(per_sample_len.max().item()), 5) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 3, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) + "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], 
dtype=torch.long).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + additional_t_cond=addition_t_cond, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break()
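
The changes above drop the explicit `txt_seq_lens` argument from the QwenImage pipelines and let `QwenImageTransformer2DModel` infer text sequence lengths from `encoder_hidden_states_mask` via the new `compute_text_seq_len_from_mask` helper (passing `txt_seq_lens` directly now only triggers a `FutureWarning`). The sketch below illustrates the helper's contract as exercised by the new transformer tests; the tensor shapes and mask values are invented for illustration and are not taken from the test suite.

```python
import torch
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

# Dummy prompt embeddings: batch of 2, 7 text tokens, hidden size 16 (shapes chosen arbitrarily).
encoder_hidden_states = torch.randn(2, 7, 16)

# Mask with trailing padding, mirroring the tests: sample 0 keeps 5 tokens, sample 1 keeps 3.
encoder_hidden_states_mask = torch.ones(2, 7)
encoder_hidden_states_mask[0, 5:] = 0
encoder_hidden_states_mask[1, 3:] = 0

rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
    encoder_hidden_states, encoder_hidden_states_mask
)

# Per the new tests: rope_text_seq_len is a plain int (at least the padded sequence length),
# per_sample_len is a tensor holding each sample's last valid position + 1, and
# normalized_mask is the mask normalized to torch.bool.
print(rope_text_seq_len, per_sample_len, normalized_mask.dtype)

# Pipelines now simply forward the mask instead of precomputing txt_seq_lens, e.g.:
# noise_pred = transformer(
#     hidden_states=latents,
#     timestep=timestep,
#     encoder_hidden_states=encoder_hidden_states,
#     encoder_hidden_states_mask=normalized_mask,
#     img_shapes=img_shapes,
#     return_dict=False,
# )[0]
```

Keeping `rope_text_seq_len` as a Python int rather than a per-sample list avoids data-dependent shapes in the RoPE path, which appears to be why the `xfail` on `test_torch_compile_recompilation_and_graph_break` is removed at the end of the diff.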