mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fixes
This commit is contained in:
@@ -1252,6 +1252,10 @@ class FluxPosEmbed(nn.Module):
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
was_unbatched = ids.ndim == 2
|
||||
if was_unbatched:
|
||||
# Add a batch dimension to standardize processing
|
||||
ids = ids.unsqueeze(0)
|
||||
# ids is now expected to be [B, S, n_axes]
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
@@ -1277,7 +1281,7 @@ class FluxPosEmbed(nn.Module):
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
|
||||
# Squeeze the batch dim if the original input was unbatched
|
||||
if ids.ndim == 2:
|
||||
if was_unbatched:
|
||||
freqs_cos = freqs_cos.squeeze(0)
|
||||
freqs_sin = freqs_sin.squeeze(0)
|
||||
|
||||
|
||||
@@ -456,23 +456,19 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# if txt_ids.ndim == 3:
|
||||
# logger.warning(
|
||||
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
# )
|
||||
# txt_ids = txt_ids[0]
|
||||
# if img_ids.ndim == 3:
|
||||
# logger.warning(
|
||||
# "Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
# )
|
||||
# img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 2:
|
||||
txt_ids = txt_ids.unsqueeze(0)
|
||||
if img_ids.ndim == 2:
|
||||
img_ids = img_ids.unsqueeze(0)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
if txt_ids.ndim == 3:
|
||||
# logger.warning(
|
||||
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
# )
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
# logger.warning(
|
||||
# "Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
# )
|
||||
img_ids = img_ids[0]
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
|
||||
Reference in New Issue
Block a user