mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
sayakpaul
2025-07-07 14:22:10 +05:30
parent 666a3d9448
commit 6cc6c130cf
2 changed files with 18 additions and 18 deletions

src/diffusers/models/embeddings.py

@@ -1252,6 +1252,10 @@ class FluxPosEmbed(nn.Module):
         self.axes_dim = axes_dim
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        was_unbatched = ids.ndim == 2
+        if was_unbatched:
+            # Add a batch dimension to standardize processing
+            ids = ids.unsqueeze(0)
         # ids is now expected to be [B, S, n_axes]
         n_axes = ids.shape[-1]
         cos_out = []
@@ -1277,7 +1281,7 @@ class FluxPosEmbed(nn.Module):
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
         # Squeeze the batch dim if the original input was unbatched
-        if ids.ndim == 2:
+        if was_unbatched:
             freqs_cos = freqs_cos.squeeze(0)
             freqs_sin = freqs_sin.squeeze(0)
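Net effect of the embeddings change: FluxPosEmbed.forward now accepts either unbatched [S, n_axes] ids or batched [B, S, n_axes] ids, lifting the unbatched case to a batch of one and squeezing the result back before returning. A minimal standalone sketch of that pattern (forward_like and the ids.float() placeholder are illustrative stand-ins, not the real rotary-frequency computation):

import torch

def forward_like(ids: torch.Tensor) -> torch.Tensor:
    # Mirrors the unsqueeze/squeeze pattern added in the hunk above.
    was_unbatched = ids.ndim == 2
    if was_unbatched:
        ids = ids.unsqueeze(0)  # [S, n_axes] -> [1, S, n_axes]
    freqs = ids.float()  # placeholder for the per-axis cos/sin computation
    if was_unbatched:
        freqs = freqs.squeeze(0)  # restore the caller's original rank
    return freqs

print(forward_like(torch.zeros(4096, 3)).shape)     # torch.Size([4096, 3])
print(forward_like(torch.zeros(2, 4096, 3)).shape)  # torch.Size([2, 4096, 3])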

src/diffusers/models/transformers/transformer_flux.py

@@ -456,23 +456,19 @@ class FluxTransformer2DModel(
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        # if txt_ids.ndim == 3:
-        #     logger.warning(
-        #         "Passing `txt_ids` 3d torch.Tensor is deprecated."
-        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
-        #     )
-        #     txt_ids = txt_ids[0]
-        # if img_ids.ndim == 3:
-        #     logger.warning(
-        #         "Passing `img_ids` 3d torch.Tensor is deprecated."
-        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
-        #     )
-        #     img_ids = img_ids[0]
-        if txt_ids.ndim == 2:
-            txt_ids = txt_ids.unsqueeze(0)
-        if img_ids.ndim == 2:
-            img_ids = img_ids.unsqueeze(0)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            # logger.warning(
+            #     "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            # logger.warning(
+            #     "Passing `img_ids` 3d torch.Tensor is deprecated."
+            #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            # )
+            img_ids = img_ids[0]
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
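Net effect of the transformer change: txt_ids and img_ids are expected as 2D tensors of shape [seq_len, n_axes]; a deprecated 3D (batched) tensor is reduced to its first batch entry (with the warning left commented out), and the two id sets are concatenated along the sequence axis, dim=0, before being passed to pos_embed. A small standalone sketch of that shape handling (the sequence lengths 512 and 4096 are illustrative only):

import torch

txt_ids = torch.zeros(512, 3)      # [text_seq_len, n_axes], the expected 2D form
img_ids = torch.zeros(1, 4096, 3)  # a caller still passing a deprecated batched 3D tensor

if txt_ids.ndim == 3:
    txt_ids = txt_ids[0]  # keep only the first batch entry
if img_ids.ndim == 3:
    img_ids = img_ids[0]

ids = torch.cat((txt_ids, img_ids), dim=0)  # concatenate along the sequence axis
print(ids.shape)  # torch.Size([4608, 3])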