mirror of https://github.com/huggingface/diffusers.git

parameter names match the standard diffusers conventions
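The commit renames the parameters of `PhotonBlock.forward` to the standard diffusers names: `img` becomes `hidden_states`, `txt` becomes `encoder_hidden_states`, `vec` becomes `temb`, `pe` becomes `image_rotary_emb`, and the silently ignored `**_` catch-all becomes `**kwargs`. For downstream code still calling the block with the old keywords, a minimal compatibility shim might look like the following; the helper and its name are hypothetical, not part of diffusers:

# Hypothetical adapter, not part of diffusers: maps pre-rename keyword
# arguments onto the new parameter names before calling the block.
_RENAMED_KWARGS = {
    "img": "hidden_states",
    "txt": "encoder_hidden_states",
    "vec": "temb",
    "pe": "image_rotary_emb",
}

def call_block_with_old_names(block, **kwargs):
    """Translate pre-rename keyword arguments to the new diffusers names."""
    return block(**{_RENAMED_KWARGS.get(name, name): value for name, value in kwargs.items()})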
@@ -411,57 +411,54 @@ class PhotonBlock(nn.Module):
     def forward(
         self,
-        img: Tensor,
-        txt: Tensor,
-        vec: Tensor,
-        pe: Tensor,
+        hidden_states: Tensor,
+        encoder_hidden_states: Tensor,
+        temb: Tensor,
+        image_rotary_emb: Tensor,
         attention_mask: Tensor | None = None,
-        **_: dict[str, Any],
+        **kwargs: dict[str, Any],
     ) -> Tensor:
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.

         Parameters:
-            img (`torch.Tensor`):
+            hidden_states (`torch.Tensor`):
                 Image tokens of shape `(B, L_img, hidden_size)`.
-            txt (`torch.Tensor`):
+            encoder_hidden_states (`torch.Tensor`):
                 Text tokens of shape `(B, L_txt, hidden_size)`.
-            vec (`torch.Tensor`):
+            temb (`torch.Tensor`):
                 Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
                 broadcastable).
-            pe (`torch.Tensor`):
+            image_rotary_emb (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
-            **_:
-                Ignored additional keyword arguments for API compatibility.
+            **kwargs:
+                Additional keyword arguments for API compatibility.

         Returns:
             `torch.Tensor`:
                 Updated image tokens of shape `(B, L_img, hidden_size)`.
         """

-        mod_attn, mod_mlp = self.modulation(vec)
+        mod_attn, mod_mlp = self.modulation(temb)
         attn_shift, attn_scale, attn_gate = mod_attn
         mlp_shift, mlp_scale, mlp_gate = mod_mlp

         # Apply modulation and pre-normalization to image tokens
-        img_mod = (1 + attn_scale) * self.img_pre_norm(img) + attn_shift
+        hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift

         # Forward through PhotonAttention module
         attn_out = self.attention(
-            hidden_states=img_mod,
-            encoder_hidden_states=txt,
+            hidden_states=hidden_states_mod,
+            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
-            image_rotary_emb=pe,
+            image_rotary_emb=image_rotary_emb,
         )

-        img = img + attn_gate * attn_out
+        hidden_states = hidden_states + attn_gate * attn_out

         # Inline FFN forward
-        x = (1 + mlp_scale) * self.post_attention_layernorm(img) + mlp_shift
-        img = img + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
-        return img
+        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
+        hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
+        return hidden_states


 class FinalLayer(nn.Module):
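The arithmetic in the hunk above is untouched by the rename; only names move. The pattern `forward` implements is adaptive-layer-norm-style gating: the conditioning vector yields shift/scale/gate triples, each branch sees `(1 + scale) * norm(x) + shift`, and its output re-enters the residual stream through a gate. Below is a self-contained sketch of that pattern; the `nn.Linear`-based modulation, the chunk layout, and the identity standing in for attention are assumptions for illustration, not the real `PhotonBlock` internals.

import torch
from torch import Tensor, nn


class GatedBlockSketch(nn.Module):
    """Illustrative stand-in for the modulation-gated pattern in PhotonBlock.forward."""

    def __init__(self, hidden_size: int = 64, mlp_size: int = 256) -> None:
        super().__init__()
        # One projection emitting six chunks: shift/scale/gate for the attention
        # branch and for the MLP branch (an assumption; the real Modulation
        # module may be structured differently).
        self.modulation = nn.Linear(hidden_size, 6 * hidden_size)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.gate_proj = nn.Linear(hidden_size, mlp_size)
        self.up_proj = nn.Linear(hidden_size, mlp_size)
        self.down_proj = nn.Linear(mlp_size, hidden_size)
        self.mlp_act = nn.SiLU()

    def forward(self, hidden_states: Tensor, temb: Tensor) -> Tensor:
        # (B, hidden_size) -> six (B, 1, hidden_size) chunks that broadcast
        # over the token axis.
        chunks = self.modulation(temb).unsqueeze(1).chunk(6, dim=-1)
        attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale, mlp_gate = chunks

        # Modulated pre-norm, the (1 + scale) * norm(x) + shift form from the
        # diff; attention is replaced by identity so the sketch runs without
        # the PhotonAttention module.
        attn_out = (1 + attn_scale) * self.pre_norm(hidden_states) + attn_shift
        hidden_states = hidden_states + attn_gate * attn_out

        # Inline gated FFN: down(act(gate(x)) * up(x)), with a gated residual.
        x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
        mlp_out = self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))
        return hidden_states + mlp_gate * mlp_out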
@@ -749,10 +746,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
                 )
             else:
                 img = block(
-                    img=img,
-                    txt=txt,
-                    vec=vec,
-                    pe=pe,
+                    hidden_states=img,
+                    encoder_hidden_states=txt,
+                    temb=vec,
+                    image_rotary_emb=pe,
                     attention_mask=cross_attn_mask,
                 )
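Note that in this second hunk the caller's local variables keep their old names (`img`, `txt`, `vec`, `pe`); only the keywords change. A quick shape check of the sketch above, passing keywords the way the caller now does (sizes arbitrary):

block = GatedBlockSketch(hidden_size=64)
out = block(
    hidden_states=torch.randn(2, 16, 64),  # (B, L_img, hidden_size) image tokens
    temb=torch.randn(2, 64),               # (B, hidden_size) conditioning vector
)
print(out.shape)  # torch.Size([2, 16, 64])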