diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 2f31054319..9c41bf949e 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -47,6 +47,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = freqs + x_dtype = x.dtype needs_reshape = False if x.ndim != 4 and cos.ndim == 4: # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head) @@ -61,7 +62,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten r = last // 2 # (..., 2, r) - split_x = x.reshape(*x.shape[:-1], 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float first_x = split_x[..., :1, :] # (..., 1, r) second_x = split_x[..., 1:, :] # (..., 1, r) @@ -80,6 +81,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten if needs_reshape: out = out.swapaxes(1, 2).reshape(b, t, -1) + out = out.to(dtype=x_dtype) return out