diff --git a/check_rope_batched.py b/check_rope_batched.py
new file mode 100644
index 0000000000..db059048c8
--- /dev/null
+++ b/check_rope_batched.py
@@ -0,0 +1,14 @@
+from diffusers.models.embeddings import FluxPosEmbed
+import torch
+
+batch_size = 4
+seq_length = 16
+img_seq_length = 32
+txt_ids = torch.randn(batch_size, seq_length, 3)
+img_ids = torch.randn(batch_size, img_seq_length, 3)
+
+pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])
+ids = torch.cat((txt_ids, img_ids), dim=1)
+image_rotary_emb = pos_embed(ids)
+# image_rotary_emb[0].shape=torch.Size([4, 48, 16]), image_rotary_emb[1].shape=torch.Size([4, 48, 16])
+print(f"{image_rotary_emb[0].shape=}, {image_rotary_emb[1].shape=}")
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 4f268bfa01..c24e1d223f 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1142,32 +1142,42 @@ def get_1d_rotary_pos_embed(
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
         pos = torch.arange(pos)
     if isinstance(pos, np.ndarray):
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
+    # Handle both batched [B, S] and un-batched [S] inputs
+    if pos.ndim == 1:
+        pos = pos.unsqueeze(0)  # add a batch dimension if missing
 
     theta = theta * ntk_factor
     freqs = (
         1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
-    )  # [D/2]
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    )  # Shape: [D/2]
+
+    # Replace torch.outer with a broadcasted multiplication: pos is [B, S] and
+    # freqs is [D/2], so unsqueezing pos to [B, S, 1] lets the product
+    # broadcast to [B, S, D/2].
+    freqs = pos.unsqueeze(-1) * freqs  # Shape: [B, S, D/2]
+
     is_npu = freqs.device.type == "npu"
     if is_npu:
         freqs = freqs.float()
+
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        # Interleave on the last (feature) dimension so batched inputs work too
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=-1).float()  # Shape: [B, S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=-1).float()  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # Shape: [B, S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # Shape: [B, S, D]
         return freqs_cos, freqs_sin
     else:
         # lumina
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Shape: [B, S, D/2]
         return freqs_cis
@@ -1246,6 +1256,7 @@ class FluxPosEmbed(nn.Module):
         self.axes_dim = axes_dim
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        # ids is now expected to be batched: [B, S, n_axes]
        n_axes = ids.shape[-1]
         cos_out = []
         sin_out = []
@@ -1253,10 +1264,14 @@ class FluxPosEmbed(nn.Module):
         is_mps = ids.device.type == "mps"
         is_npu = ids.device.type == "npu"
         freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+        # Un-batched [S, n_axes] ids get a temporary batch dimension here
+        if pos.ndim == 2:
+            pos = pos.unsqueeze(0)
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
-                pos[:, i],
+                pos[:, :, i],  # correct slicing for batched input
                 theta=self.theta,
                 repeat_interleave_real=True,
                 use_real=True,
@@ -1264,8 +1279,15 @@ class FluxPosEmbed(nn.Module):
             )
             cos_out.append(cos)
             sin_out.append(sin)
+
         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+
+        # Squeeze the batch dim back out if the original input was un-batched
+        if ids.ndim == 2:
+            freqs_cos = freqs_cos.squeeze(0)
+            freqs_sin = freqs_sin.squeeze(0)
+
         return freqs_cos, freqs_sin
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 3af1de2ad0..94ccf0a3da 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -456,20 +456,13 @@ class FluxTransformer2DModel(
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            img_ids = img_ids[0]
-
-        ids = torch.cat((txt_ids, img_ids), dim=0)
+        # Batched [B, S, 3] ids are now the expected format; un-batched
+        # [S, 3] ids get a singleton batch dimension for backward compatibility
+        if txt_ids.ndim == 2:
+            txt_ids = txt_ids.unsqueeze(0)
+        if img_ids.ndim == 2:
+            img_ids = img_ids.unsqueeze(0)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
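A quick equivalence check for the squeeze-back path (a minimal sketch, not part of the patch; the file name and sizes are illustrative): with the same ids repeated across the batch, the batched path should reproduce the un-batched output exactly, and un-batched ids should still come back as 2-d embeddings.

# check_rope_unbatched.py -- hypothetical companion to check_rope_batched.py
from diffusers.models.embeddings import FluxPosEmbed
import torch

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[4, 4, 8])

ids_2d = torch.randn(48, 3)                   # un-batched [S, 3]
ids_3d = ids_2d.unsqueeze(0).repeat(4, 1, 1)  # batched [B, S, 3], same ids per sample

cos_2d, sin_2d = pos_embed(ids_2d)  # squeeze-back path -> [S, D]
cos_3d, sin_3d = pos_embed(ids_3d)  # batched path -> [B, S, D]

assert cos_2d.shape == (48, 16) and cos_3d.shape == (4, 48, 16)
assert torch.allclose(cos_3d[0], cos_2d) and torch.allclose(sin_3d[0], sin_2d)
print("batched and un-batched rotary embeddings match")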