
parameter names match the standard diffusers conventions

DavidBert
2025-10-16 08:34:52 +00:00
parent 6e05172682
commit 574f8fd10a


@@ -411,57 +411,54 @@ class PhotonBlock(nn.Module):
     def forward(
         self,
-        img: Tensor,
-        txt: Tensor,
-        vec: Tensor,
-        pe: Tensor,
+        hidden_states: Tensor,
+        encoder_hidden_states: Tensor,
+        temb: Tensor,
+        image_rotary_emb: Tensor,
         attention_mask: Tensor | None = None,
-        **_: dict[str, Any],
+        **kwargs: dict[str, Any],
     ) -> Tensor:
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.
 
         Parameters:
-            img (`torch.Tensor`):
+            hidden_states (`torch.Tensor`):
                 Image tokens of shape `(B, L_img, hidden_size)`.
-            txt (`torch.Tensor`):
+            encoder_hidden_states (`torch.Tensor`):
                 Text tokens of shape `(B, L_txt, hidden_size)`.
-            vec (`torch.Tensor`):
+            temb (`torch.Tensor`):
                 Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
                 broadcastable).
-            pe (`torch.Tensor`):
+            image_rotary_emb (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
-            **_:
-                Ignored additional keyword arguments for API compatibility.
+            **kwargs:
+                Additional keyword arguments for API compatibility.
 
         Returns:
             `torch.Tensor`:
                 Updated image tokens of shape `(B, L_img, hidden_size)`.
         """
-        mod_attn, mod_mlp = self.modulation(vec)
+        mod_attn, mod_mlp = self.modulation(temb)
         attn_shift, attn_scale, attn_gate = mod_attn
         mlp_shift, mlp_scale, mlp_gate = mod_mlp
 
         # Apply modulation and pre-normalization to image tokens
-        img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift
+        hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
 
         # Forward through PhotonAttention module
         attn_out = self.attention(
-            hidden_states=img_mod,
-            encoder_hidden_states=txt,
+            hidden_states=hidden_states_mod,
+            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
-            image_rotary_emb=pe,
+            image_rotary_emb=image_rotary_emb,
         )
-        img = img + attn_gate * attn_out
+        hidden_states = hidden_states + attn_gate * attn_out
 
         # Inline FFN forward
-        x = (1 + mlp_scale) * self.post_attention_layernorm(img) + mlp_shift
-        img = img + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
-        return img
+        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
+        hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
+        return hidden_states
 
 
 class FinalLayer(nn.Module):
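
For readers skimming the diff, here is a minimal, self-contained sketch of the modulation-gated residual pattern that `PhotonBlock.forward` implements: a conditioning-derived shift/scale/gate applied around pre-normalized attention and MLP branches. It is written against plain PyTorch, not the actual Photon modules; `ToyModulation`, `ToyGatedBlock`, the head count, and the MLP layout are assumptions for illustration only, and only the keyword names `hidden_states`, `encoder_hidden_states`, and `temb` mirror the renamed signature.

import torch
from torch import Tensor, nn


class ToyModulation(nn.Module):
    # Hypothetical stand-in for the block's Modulation helper: maps the
    # conditioning embedding `temb` to two (shift, scale, gate) triples,
    # one for the attention branch and one for the MLP branch.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 6 * hidden_size)

    def forward(self, temb: Tensor):
        chunks = self.proj(temb).unsqueeze(1).chunk(6, dim=-1)  # each (B, 1, hidden_size)
        return chunks[:3], chunks[3:]


class ToyGatedBlock(nn.Module):
    # Minimal modulation-gated block: the pre-norm output is scaled/shifted by
    # the conditioning, and each residual branch is multiplied by a gate.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.modulation = ToyModulation(hidden_size)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.post_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states: Tensor, encoder_hidden_states: Tensor, temb: Tensor) -> Tensor:
        (attn_shift, attn_scale, attn_gate), (mlp_shift, mlp_scale, mlp_gate) = self.modulation(temb)

        # Modulated pre-norm, cross-attention over text tokens, gated residual.
        x = (1 + attn_scale) * self.pre_norm(hidden_states) + attn_shift
        attn_out, _ = self.attn(x, encoder_hidden_states, encoder_hidden_states)
        hidden_states = hidden_states + attn_gate * attn_out

        # Modulated post-norm, feed-forward, gated residual.
        x = (1 + mlp_scale) * self.post_norm(hidden_states) + mlp_shift
        return hidden_states + mlp_gate * self.mlp(x)


if __name__ == "__main__":
    block = ToyGatedBlock(64)
    hidden_states = torch.randn(2, 16, 64)         # image tokens (B, L_img, hidden)
    encoder_hidden_states = torch.randn(2, 8, 64)  # text tokens (B, L_txt, hidden)
    temb = torch.randn(2, 64)                      # conditioning vector (B, hidden)
    print(block(hidden_states, encoder_hidden_states, temb).shape)  # torch.Size([2, 16, 64])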
@@ -749,10 +746,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
                 )
             else:
                 img = block(
-                    img=img,
-                    txt=txt,
-                    vec=vec,
-                    pe=pe,
+                    hidden_states=img,
+                    encoder_hidden_states=txt,
+                    temb=vec,
+                    image_rotary_emb=pe,
                     attention_mask=cross_attn_mask,
                 )
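
The `cross_attn_mask` passed at this call site follows the convention stated in the docstring above: a boolean mask over text tokens of shape `(B, L_txt)` where `0`/`False` marks padding. A hedged sketch of building such a mask from padded token ids; the token ids and padding id below are made up for illustration and are not taken from the diff.

import torch

# Toy batch of padded text token ids; 0 is used as the padding id here.
token_ids = torch.tensor([
    [101, 2009, 2003, 2307, 102, 0],
    [101, 2023, 102, 0, 0, 0],
])
cross_attn_mask = token_ids != 0  # (B, L_txt) boolean, False marks padding
print(cross_attn_mask)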