1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[qwen-image] edit 2511 support (#12839)

* [qwen-image] edit 2511 support

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
naykun
2025-12-15 15:05:01 +08:00
committed by GitHub
parent 17c0e79dbd
commit b8a4cbac14

View File

@@ -14,6 +14,7 @@
import functools
import math
from math import prod
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -363,7 +364,13 @@ class QwenDoubleStreamAttnProcessor2_0:
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
zero_cond_t: bool = False,
):
super().__init__()
@@ -403,10 +410,43 @@ class QwenImageTransformerBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def _modulate(self, x, mod_params):
self.zero_cond_t = zero_cond_t
def _modulate(self, x, mod_params, index=None):
"""Apply modulation to input tensor"""
# x: b l d, shift: b d, scale: b d, gate: b d
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
if index is not None:
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
# So shift, scale, gate have shape [2*actual_batch, d]
actual_batch = shift.size(0) // 2
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
# index: [b, l] where b is actual batch size
# Expand to [b, l, 1] to match feature dimension
index_expanded = index.unsqueeze(-1) # [b, l, 1]
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
scale_0_exp = scale_0.unsqueeze(1)
scale_1_exp = scale_1.unsqueeze(1)
gate_0_exp = gate_0.unsqueeze(1)
gate_1_exp = gate_1.unsqueeze(1)
# Use torch.where to select based on index
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
else:
shift_result = shift.unsqueeze(1)
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
@@ -416,9 +456,13 @@ class QwenImageTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
modulate_index: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
if self.zero_cond_t:
temb = torch.chunk(temb, 2, dim=0)[0]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
@@ -427,7 +471,7 @@ class QwenImageTransformerBlock(nn.Module):
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ class QwenImageTransformerBlock(nn.Module):
# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
@@ -533,6 +577,7 @@ class QwenImageTransformer2DModel(
joint_attention_dim: int = 3584,
guidance_embeds: bool = False, # TODO: this should probably be removed
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ class QwenImageTransformer2DModel(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
zero_cond_t=zero_cond_t,
)
for _ in range(num_layers)
]
@@ -562,6 +608,7 @@ class QwenImageTransformer2DModel(
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t
def forward(
self,
@@ -618,6 +665,17 @@ class QwenImageTransformer2DModel(
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if self.zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
@@ -641,6 +699,8 @@ class QwenImageTransformer2DModel(
encoder_hidden_states_mask,
temb,
image_rotary_emb,
attention_kwargs,
modulate_index,
)
else:
@@ -651,6 +711,7 @@ class QwenImageTransformer2DModel(
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)
# controlnet residual
@@ -659,6 +720,8 @@ class QwenImageTransformer2DModel(
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
if self.zero_cond_t:
temb = temb.chunk(2, dim=0)[0]
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)