From e909b7355fcd7055334e245e532e79349df79a92 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 15 Jul 2025 12:25:51 +0200
Subject: [PATCH] update

---
 src/diffusers/models/embeddings.py                 | 12 ++++++++++--
 .../models/transformers/transformer_flux.py        |  4 ++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 262e57a3a0..4d3d246e48 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1176,6 +1176,7 @@ def apply_rotary_emb(
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
     use_real: bool = True,
     use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1194,15 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 2d4bc172a7..e5b45bbcd5 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -108,8 +108,8 @@ class FluxAttnProcessor:
             value = torch.cat([encoder_value, value], dim=1)
 
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         hidden_states = dispatch_attention_fn(
             query, key, value, attn_mask=attention_mask, backend=self._attention_backend
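
Note (not part of the patch): the new `sequence_dim` argument only changes which axis the
cos/sin tables broadcast over, so applying RoPE to a [B, S, H, D] tensor with
`sequence_dim=1` (the layout FluxAttnProcessor keeps query/key in) should match applying it
to the permuted [B, H, S, D] tensor with the default `sequence_dim=2`. A minimal sketch of
that check, assuming the patched `apply_rotary_emb` above and illustrative shapes:

    # Sketch only: shapes B, H, S, D are arbitrary, and freqs_cis is built from
    # random angles purely to exercise the broadcasting paths.
    import torch
    from diffusers.models.embeddings import apply_rotary_emb

    B, H, S, D = 2, 4, 16, 64
    angles = torch.rand(S, D)
    freqs_cis = (angles.cos(), angles.sin())        # (cos, sin), each [S, D]

    x_bshd = torch.randn(B, S, H, D)                # layout used by FluxAttnProcessor
    x_bhsd = x_bshd.permute(0, 2, 1, 3)             # conventional [B, H, S, D] layout

    out_seq1 = apply_rotary_emb(x_bshd, freqs_cis, sequence_dim=1)
    out_seq2 = apply_rotary_emb(x_bhsd, freqs_cis)  # sequence_dim defaults to 2

    # Same rotation, different memory layout.
    torch.testing.assert_close(out_seq1.permute(0, 2, 1, 3), out_seq2)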