diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
index b80f33e222..6c94e9f67a 100644
--- a/src/diffusers/models/transformers/transformer_photon.py
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -411,57 +411,54 @@ class PhotonBlock(nn.Module):
     def forward(
         self,
-        img: Tensor,
-        txt: Tensor,
-        vec: Tensor,
-        pe: Tensor,
+        hidden_states: Tensor,
+        encoder_hidden_states: Tensor,
+        temb: Tensor,
+        image_rotary_emb: Tensor,
         attention_mask: Tensor | None = None,
-        **_: dict[str, Any],
+        **kwargs: dict[str, Any],
     ) -> Tensor:
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.
 
         Parameters:
-            img (`torch.Tensor`):
+            hidden_states (`torch.Tensor`):
                 Image tokens of shape `(B, L_img, hidden_size)`.
-            txt (`torch.Tensor`):
+            encoder_hidden_states (`torch.Tensor`):
                 Text tokens of shape `(B, L_txt, hidden_size)`.
-            vec (`torch.Tensor`):
+            temb (`torch.Tensor`):
                 Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or broadcastable).
-            pe (`torch.Tensor`):
+            image_rotary_emb (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
-            **_:
-                Ignored additional keyword arguments for API compatibility.
+            **kwargs:
+                Additional keyword arguments for API compatibility.
 
         Returns:
             `torch.Tensor`: Updated image tokens of shape `(B, L_img, hidden_size)`.
         """
-        mod_attn, mod_mlp = self.modulation(vec)
+        mod_attn, mod_mlp = self.modulation(temb)
         attn_shift, attn_scale, attn_gate = mod_attn
         mlp_shift, mlp_scale, mlp_gate = mod_mlp
 
-        # Apply modulation and pre-normalization to image tokens
-        img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift
+        hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
 
-        # Forward through PhotonAttention module
         attn_out = self.attention(
-            hidden_states=img_mod,
-            encoder_hidden_states=txt,
+            hidden_states=hidden_states_mod,
+            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
-            image_rotary_emb=pe,
+            image_rotary_emb=image_rotary_emb,
         )
-        img = img + attn_gate * attn_out
+        hidden_states = hidden_states + attn_gate * attn_out
 
-        # Inline FFN forward
-        x = (1 + mlp_scale) * self.post_attention_layernorm(img) + mlp_shift
-        img = img + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
-        return img
+        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
+        hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
+        return hidden_states
 
 
 class FinalLayer(nn.Module):
@@ -749,10 +746,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
                 )
             else:
                 img = block(
-                    img=img,
-                    txt=txt,
-                    vec=vec,
-                    pe=pe,
+                    hidden_states=img,
+                    encoder_hidden_states=txt,
+                    temb=vec,
+                    image_rotary_emb=pe,
                     attention_mask=cross_attn_mask,
                 )
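
Below is a minimal, self-contained sketch of the modulation-gated pattern that `PhotonBlock.forward` implements, written against the renamed arguments (`hidden_states` = image tokens, `encoder_hidden_states` = text tokens, `temb` = conditioning vector). `DemoGatedBlock`, its layer choices (plain `nn.MultiheadAttention`, no rotary embeddings or attention mask), and the tensor shapes are illustrative assumptions rather than the diffusers implementation; the point is the `(1 + scale) * norm(h) + shift` modulation followed by a `gate`-scaled residual in both the attention and MLP branches.

```python
import torch
import torch.nn as nn


class DemoGatedBlock(nn.Module):
    """Toy stand-in (not the diffusers PhotonBlock): AdaLN-style modulation around cross-attention and a gated MLP."""

    def __init__(self, hidden_size: int, num_heads: int = 4, mlp_ratio: int = 4):
        super().__init__()
        # temb -> (shift, scale, gate) for the attention branch and for the MLP branch.
        self.modulation = nn.Linear(hidden_size, 6 * hidden_size)
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        # SwiGLU-style MLP, mirroring the gate_proj / up_proj / down_proj names in the diff.
        self.gate_proj = nn.Linear(hidden_size, mlp_ratio * hidden_size)
        self.up_proj = nn.Linear(hidden_size, mlp_ratio * hidden_size)
        self.down_proj = nn.Linear(mlp_ratio * hidden_size, hidden_size)
        self.mlp_act = nn.SiLU()

    def forward(self, hidden_states, encoder_hidden_states, temb):
        # Split the conditioning vector into per-branch shift/scale/gate terms, broadcast over tokens.
        attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale, mlp_gate = (
            self.modulation(temb)[:, None].chunk(6, dim=-1)
        )

        # Attention branch: modulate, cross-attend image tokens to text tokens, gated residual.
        modulated = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
        attn_out, _ = self.attention(modulated, encoder_hidden_states, encoder_hidden_states)
        hidden_states = hidden_states + attn_gate * attn_out

        # MLP branch: same modulate -> transform -> gated-residual pattern.
        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
        return hidden_states + mlp_gate * self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))


block = DemoGatedBlock(hidden_size=64)
hidden_states = torch.randn(2, 16, 64)         # image tokens (B, L_img, hidden_size)
encoder_hidden_states = torch.randn(2, 8, 64)  # text tokens (B, L_txt, hidden_size)
temb = torch.randn(2, 64)                      # conditioning vector (B, hidden_size)
out = block(hidden_states, encoder_hidden_states, temb)
print(out.shape)  # torch.Size([2, 16, 64])
```

The diff itself only renames keyword arguments; the computation sketched above is unchanged, which is why the caller inside `PhotonTransformer2DModel` simply switches to the new keywords.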