1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-07-15 12:25:51 +02:00
parent 576da52f45
commit e909b7355f
2 changed files with 12 additions and 4 deletions

View File

@@ -1176,6 +1176,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1194,15 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
if sequence_dim == 2:
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
elif sequence_dim == 1:
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
else:
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:

View File

@@ -108,8 +108,8 @@ class FluxAttnProcessor:
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend