mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
When using split RoPE, make sure that the output dtype is the same as the input dtype
This commit is contained in:
@@ -47,6 +47,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor
|
||||
def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
|
||||
x_dtype = x.dtype
|
||||
needs_reshape = False
|
||||
if x.ndim != 4 and cos.ndim == 4:
|
||||
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
|
||||
@@ -61,7 +62,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
|
||||
r = last // 2
|
||||
|
||||
# (..., 2, r)
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r)
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
|
||||
first_x = split_x[..., :1, :] # (..., 1, r)
|
||||
second_x = split_x[..., 1:, :] # (..., 1, r)
|
||||
|
||||
@@ -80,6 +81,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
|
||||
if needs_reshape:
|
||||
out = out.swapaxes(1, 2).reshape(b, t, -1)
|
||||
|
||||
out = out.to(dtype=x_dtype)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user