From 6cc6c130cfb75c268d1ade575720f973803eef30 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 7 Jul 2025 14:22:10 +0530
Subject: [PATCH] Fix 2D/3D ids handling in Flux transformer and FluxPosEmbed

---
 src/diffusers/models/embeddings.py           |  6 +++-
 .../models/transformers/transformer_flux.py  | 30 ++++++++-----------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 40e559bd75..83bae7785b 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1252,6 +1252,10 @@ class FluxPosEmbed(nn.Module):
         self.axes_dim = axes_dim
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        was_unbatched = ids.ndim == 2
+        if was_unbatched:
+            # Add a batch dimension to standardize processing
+            ids = ids.unsqueeze(0)
         # ids is now expected to be [B, S, n_axes]
         n_axes = ids.shape[-1]
         cos_out = []
@@ -1277,7 +1281,7 @@
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
 
         # Squeeze the batch dim if the original input was unbatched
-        if ids.ndim == 2:
+        if was_unbatched:
             freqs_cos = freqs_cos.squeeze(0)
             freqs_sin = freqs_sin.squeeze(0)
         return freqs_cos, freqs_sin
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 94ccf0a3da..086268f098 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -456,23 +456,19 @@ class FluxTransformer2DModel(
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        # if txt_ids.ndim == 3:
-        #     logger.warning(
-        #         "Passing `txt_ids` 3d torch.Tensor is deprecated."
-        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
-        #     )
-        #     txt_ids = txt_ids[0]
-        # if img_ids.ndim == 3:
-        #     logger.warning(
-        #         "Passing `img_ids` 3d torch.Tensor is deprecated."
-        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
-        #     )
-        #     img_ids = img_ids[0]
-        if txt_ids.ndim == 2:
-            txt_ids = txt_ids.unsqueeze(0)
-        if img_ids.ndim == 2:
-            img_ids = img_ids.unsqueeze(0)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            # logger.warning(
+            #     "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            # logger.warning(
+            #     "Passing `img_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
+            img_ids = img_ids[0]
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
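
After this patch, FluxPosEmbed.forward accepts either unbatched [S, n_axes] ids or batched [B, S, n_axes] ids and mirrors the input rank in its output, while the transformer squeezes deprecated 3D txt_ids/img_ids back to 2D and concatenates them along the sequence axis (dim=0). Below is a minimal sketch of that contract, assuming only torch: the frequency math is a toy stand-in for diffusers' get_1d_rotary_pos_embed, and the pos_embed_sketch name and default axes_dim are illustrative, not the library's API.

# A minimal sketch of the batched/unbatched contract this patch gives
# FluxPosEmbed.forward; the per-axis frequency math is a toy stand-in.
import torch

def pos_embed_sketch(ids: torch.Tensor, axes_dim=(16, 56, 56)):
    was_unbatched = ids.ndim == 2
    if was_unbatched:
        # Standardize on [B, S, n_axes] internally, as the patch does.
        ids = ids.unsqueeze(0)
    cos_out, sin_out = [], []
    for i in range(ids.shape[-1]):
        pos = ids[..., i].float()                                # [B, S]
        freqs = pos[..., None] * torch.arange(axes_dim[i] // 2)  # [B, S, d/2]
        cos_out.append(freqs.cos())
        sin_out.append(freqs.sin())
    freqs_cos = torch.cat(cos_out, dim=-1)
    freqs_sin = torch.cat(sin_out, dim=-1)
    if was_unbatched:
        # Mirror the input rank: 2D ids in, 2D embeddings out.
        freqs_cos = freqs_cos.squeeze(0)
        freqs_sin = freqs_sin.squeeze(0)
    return freqs_cos, freqs_sin

# Unbatched [S, 3] and batched [1, S, 3] inputs agree up to the batch dim.
ids_2d = torch.arange(30).reshape(10, 3).float()
cos_2d, sin_2d = pos_embed_sketch(ids_2d)
cos_3d, sin_3d = pos_embed_sketch(ids_2d.unsqueeze(0))
assert torch.equal(cos_2d, cos_3d[0]) and torch.equal(sin_2d, sin_3d[0])

Capturing the squeeze decision in was_unbatched before the unsqueeze is what makes the round trip safe: re-checking ids.ndim == 2 at the end can never fire once the batch dimension has been added, which is the bug the embeddings.py hunk fixes.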