mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Use real-valued instead of complex tensors in Wan2.1 RoPE (#11649)
* use real instead of complex tensors in Wan2.1 RoPE * remove the redundant type conversion * unpack rotary_emb * register rotary embedding frequencies as non-persistent buffers * Apply style fixes --------- Co-authored-by: Aryan <aryan@huggingface.co> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -71,14 +71,22 @@ class WanAttnProcessor2_0:
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
|
||||
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
|
||||
x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
|
||||
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
|
||||
return x_out.type_as(hidden_states)
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
|
||||
x1, x2 = x[..., 0], x[..., 1]
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query = apply_rotary_emb(query, rotary_emb)
|
||||
key = apply_rotary_emb(key, rotary_emb)
|
||||
query = apply_rotary_emb(query, *rotary_emb)
|
||||
key = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# I2V task
|
||||
hidden_states_img = None
|
||||
@@ -179,7 +187,11 @@ class WanTimeTextImageEmbedding(nn.Module):
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -189,36 +201,52 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
freqs = []
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq = get_1d_rotary_pos_embed(
|
||||
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
|
||||
freq_cos, freq_sin = get_1d_rotary_pos_embed(
|
||||
dim,
|
||||
max_seq_len,
|
||||
theta,
|
||||
use_real=True,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
freqs.append(freq)
|
||||
self.freqs = torch.cat(freqs, dim=1)
|
||||
freqs_cos.append(freq_cos)
|
||||
freqs_sin.append(freq_sin)
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
freqs = self.freqs.to(hidden_states.device)
|
||||
freqs = freqs.split_with_sizes(
|
||||
[
|
||||
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
|
||||
self.attention_head_dim // 6,
|
||||
self.attention_head_dim // 6,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
return freqs
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class WanTransformerBlock(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user