mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers
  - Store t_dim, h_dim, and w_dim as instance variables in the __init__ of WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed
  - Use the stored dimensions in forward() instead of recalculating them with a different formula
  - Fixes the inconsistency between __init__ (using // 6) and forward() (using // 3)
  - Ensures split_sizes matches the dimensions used to create the rotary embeddings
* quality fix
---------
Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
This commit is contained in:
@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
Reference in New Issue
Block a user