
Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)

* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, and w_dim as instance attributes in the __init__ of WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed
- Use the stored dimensions in forward() instead of recomputing them with a different formula
- Fixes the inconsistency between __init__ (which uses // 6) and forward() (which used // 3)
- Ensures split_sizes matches the dimensions used to create the rotary embeddings

* quality fix

---------

Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
Author: Charchit Sharma
Date: 2025-11-12 03:15:36 +05:30 (committed by GitHub)
Parent: 66e6a0215f
Commit: ff263947ad
2 changed files with 11 additions and 10 deletions
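
Why the two formulas disagree, as a minimal standalone sketch. Only the formulas come from the diff below; attention_head_dim = 64 is an assumed example value, not one taken from the commit:

# Sizes used in __init__ to build the per-axis rotary tables (h/w use // 6).
attention_head_dim = 64
h_dim = w_dim = 2 * (attention_head_dim // 6)               # 20
t_dim = attention_head_dim - h_dim - w_dim                  # 24

# Sizes the old forward() recomputed with // 3.
old_split = [
    attention_head_dim - 2 * (attention_head_dim // 3),     # 22
    attention_head_dim // 3,                                 # 21
    attention_head_dim // 3,                                 # 21
]

print([t_dim, h_dim, w_dim])  # [24, 20, 20]
print(old_split)              # [22, 21, 21]
# Both lists sum to 64, so the split itself still runs, but its boundaries no
# longer match the widths the freqs buffers were built with. The two formulas
# agree only when attention_head_dim % 6 is 0, 1, or 2; for 3, 4, or 5 they
# drift apart, which is what reusing the stored t_dim/h_dim/w_dim avoids.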


@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
         t_dim = attention_head_dim - h_dim - w_dim
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
 
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
+
         freqs_cos = []
         freqs_sin = []
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
 
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)


@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
+
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
+
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
 
         freqs_cos = []
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
 
        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
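
A quick way to see what "split_sizes matches the dimensions used to create the rotary embeddings" buys: the registered buffers are concatenations of per-axis tables, so only the stored widths recover those tables cleanly. This is a hedged toy check with stand-in tensors of the same widths, not the real buffers or module code; the example widths again assume attention_head_dim = 64:

import torch

max_seq_len = 4
t_dim, h_dim, w_dim = 24, 20, 20  # per-axis widths as computed in __init__

# Stand-ins for the per-axis tables: 0s = temporal, 1s = height, 2s = width.
freqs_cos = torch.cat(
    [
        torch.zeros(max_seq_len, t_dim),
        torch.ones(max_seq_len, h_dim),
        torch.full((max_seq_len, w_dim), 2.0),
    ],
    dim=1,
)  # width 64, like the concatenated buffer

good = freqs_cos.split([t_dim, h_dim, w_dim], dim=1)   # fixed forward()
print([g[0, 0].item() for g in good])                  # [0.0, 1.0, 2.0]

bad = freqs_cos.split([22, 21, 21], dim=1)             # old forward() formula
print(bad[1][0, :3])                                   # tensor([0., 0., 1.])
# The first two columns of the "height" chunk still belong to the temporal
# table, so any later per-axis use of these chunks pairs temporal and height
# frequencies with the wrong axes.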