Mirror of https://github.com/huggingface/diffusers.git
[qwen-image] edit 2511 support (#12839)
* [qwen-image] edit 2511 support

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@@ -14,6 +14,7 @@
 
 import functools
 import math
+from math import prod
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -363,7 +364,13 @@ class QwenDoubleStreamAttnProcessor2_0:
 @maybe_allow_in_graph
 class QwenImageTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+        zero_cond_t: bool = False,
     ):
         super().__init__()
 
@@ -403,10 +410,43 @@ class QwenImageTransformerBlock(nn.Module):
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
         self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def _modulate(self, x, mod_params):
+        self.zero_cond_t = zero_cond_t
+
+    def _modulate(self, x, mod_params, index=None):
         """Apply modulation to input tensor"""
+        # x: b l d, shift: b d, scale: b d, gate: b d
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+        if index is not None:
+            # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
+            # So shift, scale, gate have shape [2*actual_batch, d]
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]  # each: [actual_batch, d]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+
+            # index: [b, l] where b is actual batch size
+            # Expand to [b, l, 1] to match feature dimension
+            index_expanded = index.unsqueeze(-1)  # [b, l, 1]
+
+            # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
+            shift_0_exp = shift_0.unsqueeze(1)  # [b, 1, d]
+            shift_1_exp = shift_1.unsqueeze(1)  # [b, 1, d]
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
+
+            # Use torch.where to select based on index
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+
+        return x * (1 + scale_result) + shift_result, gate_result
 
     def forward(
         self,
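Note: a minimal standalone sketch of the per-token selection the new `_modulate` index path performs. The sizes below (batch, seq_len, dim) are made up for illustration; only the selection logic mirrors the hunk above: the modulation parameters arrive with a doubled batch (first half from the real timestep, second half from the zeroed one), and torch.where picks one of the two per token according to a 0/1 index. The same selection is applied to scale and gate.

import torch

batch, seq_len, dim = 2, 6, 4
# Doubled-batch modulation parameters: rows [0:batch] come from the real timestep,
# rows [batch:2*batch] from the zeroed timestep.
shift = torch.randn(2 * batch, dim)

# 0 = use the first parameter set, 1 = use the second, chosen per token.
index = torch.randint(0, 2, (batch, seq_len))

shift_0, shift_1 = shift[:batch], shift[batch:]  # each [batch, dim]
index_expanded = index.unsqueeze(-1)             # [batch, seq_len, 1]
shift_result = torch.where(
    index_expanded == 0,
    shift_0.unsqueeze(1),  # broadcast to [batch, seq_len, dim]
    shift_1.unsqueeze(1),
)
print(shift_result.shape)  # torch.Size([2, 6, 4])
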
@@ -416,9 +456,13 @@ class QwenImageTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+
+        if self.zero_cond_t:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
 
         # Split modulation parameters for norm1 and norm2
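Note: with zero_cond_t enabled, temb reaches the block with a doubled batch (real-timestep embeddings concatenated with zero-timestep embeddings), so the image-stream modulation sees both parameter sets while the text-stream modulation is computed from the first half only. A small shape sketch with dummy tensors standing in for the real embeddings:

import torch

batch, dim = 2, 8
temb_real = torch.randn(batch, dim)              # embedding of the actual timestep
temb_zero = torch.randn(batch, dim)              # embedding of the zeroed timestep
temb = torch.cat([temb_real, temb_zero], dim=0)  # [2*batch, dim], as fed to the block

img_mod_in = temb                                # image stream keeps both halves
txt_mod_in = torch.chunk(temb, 2, dim=0)[0]      # text stream keeps only the real half
print(img_mod_in.shape, txt_mod_in.shape)        # torch.Size([4, 8]) torch.Size([2, 8])
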
@@ -427,7 +471,7 @@ class QwenImageTransformerBlock(nn.Module):
 
         # Process image stream - norm1 + modulation
         img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index)
 
         # Process text stream - norm1 + modulation
         txt_normed = self.txt_norm1(encoder_hidden_states)
@@ -457,7 +501,7 @@ class QwenImageTransformerBlock(nn.Module):
 
         # Process image stream - norm2 + MLP
         img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index)
         img_mlp_output = self.img_mlp(img_modulated2)
         hidden_states = hidden_states + img_gate2 * img_mlp_output
 
@@ -533,6 +577,7 @@ class QwenImageTransformer2DModel(
         joint_attention_dim: int = 3584,
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        zero_cond_t: bool = False,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -553,6 +598,7 @@ class QwenImageTransformer2DModel(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    zero_cond_t=zero_cond_t,
                 )
                 for _ in range(num_layers)
             ]
@@ -562,6 +608,7 @@ class QwenImageTransformer2DModel(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
+        self.zero_cond_t = zero_cond_t
 
     def forward(
         self,
@@ -618,6 +665,17 @@ class QwenImageTransformer2DModel(
         hidden_states = self.img_in(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype)
+
+        if self.zero_cond_t:
+            timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = torch.tensor(
+                [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
+                device=timestep.device,
+                dtype=torch.int,
+            )
+        else:
+            modulate_index = None
+
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
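Note: a worked example of the modulate_index construction above. Each entry of img_shapes lists the packed latent shapes for one sample; the first shape appears to be the latent being denoised and the remaining shapes the conditioning (edit/reference) images, so their tokens get index 1 and therefore the zero-timestep modulation. The shapes below are made up for illustration:

from math import prod

import torch

# Hypothetical packed shapes for one sample: the denoised latent plus one reference image.
img_shapes = [[(1, 4, 4), (1, 2, 2)]]

modulate_index = torch.tensor(
    [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
    dtype=torch.int,
)
print(modulate_index.shape)    # torch.Size([1, 20]) -> 16 latent tokens + 4 reference tokens
print(modulate_index.tolist())  # [[0, 0, ..., 0, 1, 1, 1, 1]]
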
@@ -641,6 +699,8 @@ class QwenImageTransformer2DModel(
                     encoder_hidden_states_mask,
                     temb,
                     image_rotary_emb,
+                    attention_kwargs,
+                    modulate_index,
                 )
 
             else:
@@ -651,6 +711,7 @@ class QwenImageTransformer2DModel(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=attention_kwargs,
+                    modulate_index=modulate_index,
                 )
 
             # controlnet residual
@@ -659,6 +720,8 @@ class QwenImageTransformer2DModel(
                 interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
+        if self.zero_cond_t:
+            temb = temb.chunk(2, dim=0)[0]
         # Use only the image part (hidden_states) from the dual-stream blocks
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
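Note: a brief shape recap of the zero_cond_t bookkeeping in forward, using dummy tensors. The timestep batch is doubled on entry, both halves flow through the blocks as modulation inputs, and temb is chunked back to the original batch before norm_out so it matches hidden_states again (values are illustrative only):

import torch

batch, dim = 2, 8
timestep = torch.rand(batch)
timestep = torch.cat([timestep, timestep * 0], dim=0)  # [2*batch]: real + zeroed timesteps
temb = torch.randn(2 * batch, dim)                     # stand-in for the timestep embedding
hidden_states = torch.randn(batch, 20, dim)            # image tokens keep the original batch

temb = temb.chunk(2, dim=0)[0]                         # back to [batch, dim] for norm_out
assert temb.shape[0] == hidden_states.shape[0]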