diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
index 533eb356e0..c7b5ca5186 100644
--- a/src/diffusers/models/transformers/transformer_photon.py
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -136,8 +136,8 @@ class PhotonAttnProcessor2_0:
         return attn_output
 
 
-# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class EmbedND(nn.Module):
+# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+class PhotoEmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
 
@@ -672,7 +672,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
         self.hidden_size = hidden_size
         self.num_heads = num_heads
 
-        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+        self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
         self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
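
For context on what the renamed class does, the sketch below roughly follows the Flux-style N-dimensional rotary embedding that the comment links to: each positional axis gets its own rotary table of width `axes_dim[i]`, and the per-axis tables are concatenated along the feature dimension. This is a minimal, self-contained illustration under those assumptions, not the diffusers implementation; the `rope` helper and the `EmbedNDSketch` name are placeholders.

```python
# Minimal sketch of a Flux-style N-dimensional rotary embedding, for reference
# only; class and helper names here are illustrative, not the diffusers API.
import torch
from torch import nn


def rope(pos: torch.Tensor, dim: int, theta: float) -> torch.Tensor:
    """Rotary factors for one positional axis, shape (..., seq, dim // 2, 2, 2)."""
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Outer product of positions and inverse frequencies: (..., seq, dim // 2).
    out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
    # Pack cos/sin into 2x2 rotation matrices per frequency.
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return out.reshape(*out.shape[:-1], 2, 2).float()


class EmbedNDSketch(nn.Module):
    """Concatenates one rotary table per axis; axes_dim should sum to the per-head dim."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq, n_axes) coordinates, one column per axis (e.g. t, h, w).
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # Add a head dimension so the result broadcasts inside attention:
        # (batch, 1, seq, sum(axes_dim) // 2, 2, 2).
        return emb.unsqueeze(1)
```

As a usage check, with `axes_dim=[16, 56, 56]`, `theta=10_000`, and `ids` of shape `(1, 64, 3)`, the forward pass returns a `(1, 1, 64, 64, 2, 2)` tensor of 2x2 rotation factors that an attention processor can apply to queries and keys.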