mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -1176,6 +1176,7 @@ def apply_rotary_emb(
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
sequence_dim: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
@@ -1193,8 +1194,15 @@ def apply_rotary_emb(
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
if sequence_dim == 2:
|
||||
cos = cos[None, None, :, :]
|
||||
sin = sin[None, None, :, :]
|
||||
elif sequence_dim == 1:
|
||||
cos = cos[None, :, None, :]
|
||||
sin = sin[None, :, None, :]
|
||||
else:
|
||||
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
||||
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
|
||||
@@ -108,8 +108,8 @@ class FluxAttnProcessor:
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
|
||||
|
||||
Reference in New Issue
Block a user