From af8882d7e6be522b341f527a6ed865f016be82f8 Mon Sep 17 00:00:00 2001
From: David Bertoin
Date: Mon, 13 Oct 2025 11:29:43 +0000
Subject: [PATCH] remove modulation dataclass

---
 .../models/transformers/transformer_photon.py | 44 ++++++++-----------
 1 file changed, 19 insertions(+), 25 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
index c7b5ca5186..46565fd1d7 100644
--- a/src/diffusers/models/transformers/transformer_photon.py
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -228,29 +227,20 @@ class QKNorm(torch.nn.Module):
         k = self.key_norm(k)
         return q.to(v), k.to(v)
 
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
 class Modulation(nn.Module):
     r"""
     Modulation network that generates scale, shift, and gating parameters.
 
     Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
-    two `ModulationOut` objects.
+    two tuples `(shift, scale, gate)`.
 
     Parameters:
         dim (`int`):
             Dimensionality of the input vector. The output will have `6 * dim` features internally.
 
     Returns:
-        (`ModulationOut`, `ModulationOut`):
-            A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift,
-            gate).
+        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+            Two tuples `(shift, scale, gate)`.
     """
 
     def __init__(self, dim: int):
@@ -259,9 +249,9 @@ class Modulation(nn.Module):
         nn.init.constant_(self.lin.weight, 0)
         nn.init.constant_(self.lin.bias, 0)
 
-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
+    def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
-        return ModulationOut(*out[:3]), ModulationOut(*out[3:])
+        return tuple(out[:3]), tuple(out[3:])
 
 
 class PhotonBlock(nn.Module):
@@ -301,7 +291,7 @@ class PhotonBlock(nn.Module):
         modulation (`Modulation`):
             Produces scale/shift/gating parameters for modulated layers.
 
-    Methods:
+    Methods:
         attn_forward(img, txt, pe, modulation, attention_mask=None):
             Compute cross-attention between image and text tokens, with optional attention masking.
 
             Parameters:
                 img (`torch.Tensor`):
                     Image tokens of shape `(B, L_img, hidden_size)`.
                 txt (`torch.Tensor`):
                     Text tokens of shape `(B, L_txt, hidden_size)`.
                 pe (`torch.Tensor`):
                     Rotary positional embeddings to apply to queries and keys.
-                modulation (`ModulationOut`):
-                    Scale and shift parameters for modulating image tokens.
+                modulation ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+                    Tuple `(shift, scale, gate)` for modulating image tokens.
                 attention_mask (`torch.Tensor`, *optional*):
                     Boolean mask of shape `(B, L_txt)` where 0 marks padding.
@@ -371,11 +361,12 @@ class PhotonBlock(nn.Module):
         img: Tensor,
         txt: Tensor,
         pe: Tensor,
-        modulation: ModulationOut,
+        modulation: tuple[Tensor, Tensor, Tensor],
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
-        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
+        shift, scale, _gate = modulation
+        img_mod = (1 + scale) * self.img_pre_norm(img) + shift
         img_qkv = self.img_qkv_proj(img_mod)
 
         # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
@@ -433,8 +424,9 @@ class PhotonBlock(nn.Module):
 
         return attn
 
-    def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
-        x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
+    def _ffn_forward(self, x: Tensor, modulation: tuple[Tensor, Tensor, Tensor]) -> Tensor:
+        shift, scale, _gate = modulation
+        x = (1 + scale) * self.post_attention_layernorm(x) + shift
         return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))
 
     def forward(
@@ -470,15 +462,17 @@ class PhotonBlock(nn.Module):
         """
 
         mod_attn, mod_mlp = self.modulation(vec)
+        attn_shift, attn_scale, attn_gate = mod_attn
+        mlp_shift, mlp_scale, mlp_gate = mod_mlp
 
-        img = img + mod_attn.gate * self._attn_forward(
+        img = img + attn_gate * self._attn_forward(
             img,
             txt,
             pe,
-            mod_attn,
+            (attn_shift, attn_scale, attn_gate),
             attention_mask=attention_mask,
         )
-        img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
+        img = img + mlp_gate * self._ffn_forward(img, (mlp_shift, mlp_scale, mlp_gate))
 
         return img
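
Note: the snippet below is a minimal, self-contained sketch of the tuple-based modulation this patch introduces. The Modulation module mirrors the patched code; everything around it (the toy dimensions, the plain nn.LayerNorm, the gated residual stand-in, and the __main__ driver) is illustrative only and is not part of transformer_photon.py.

    import torch
    from torch import Tensor, nn


    class Modulation(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.lin = nn.Linear(dim, 6 * dim, bias=True)
            nn.init.constant_(self.lin.weight, 0)
            nn.init.constant_(self.lin.bias, 0)

        def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
            # Project the conditioning vector to six chunks and split them into
            # two (shift, scale, gate) tuples, as in the patched module.
            out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
            return tuple(out[:3]), tuple(out[3:])


    if __name__ == "__main__":
        dim = 8
        mod = Modulation(dim)
        vec = torch.randn(2, dim)          # (B, dim) conditioning vector
        x = torch.randn(2, 16, dim)        # (B, L, dim) token sequence

        mod_attn, mod_mlp = mod(vec)
        shift, scale, gate = mod_attn      # plain tuple unpacking replaces ModulationOut fields

        # Apply the modulation the same way the patched block does before attention:
        # scale/shift the normalized tokens, then gate the residual update.
        # (Toy norm and residual; the real block uses its own pre-norm and attention.)
        norm = nn.LayerNorm(dim, elementwise_affine=False)
        x_mod = (1 + scale) * norm(x) + shift
        x = x + gate * x_mod               # stand-in for the gated attention output

        print(x_mod.shape, len(mod_mlp))   # torch.Size([2, 16, 8]) 3

Behaviour is unchanged relative to the dataclass version: the linear projection still yields the same six chunks in the same (shift, scale, gate) order; only the container type changes from ModulationOut to a plain tuple.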