mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove modulation dataclass

Author:       David Bertoin
Date:         2025-10-13 11:29:43 +00:00
Committed by: DavidBert
Parent:       5f0bf0181f
Commit:       af8882d7e6


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -228,29 +227,20 @@ class QKNorm(torch.nn.Module):
         k = self.key_norm(k)
         return q.to(v), k.to(v)
 
 
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
 class Modulation(nn.Module):
     r"""
     Modulation network that generates scale, shift, and gating parameters.
 
     Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
-    two `ModulationOut` objects.
+    two tuples `(shift, scale, gate)`.
 
     Parameters:
         dim (`int`):
             Dimensionality of the input vector. The output will have `6 * dim` features internally.
 
     Returns:
-        (`ModulationOut`, `ModulationOut`):
-            A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift,
-            gate).
+        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+            Two tuples `(shift, scale, gate)`.
     """
 
     def __init__(self, dim: int):
@@ -259,9 +249,9 @@ class Modulation(nn.Module):
         nn.init.constant_(self.lin.weight, 0)
         nn.init.constant_(self.lin.bias, 0)
 
-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
+    def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
-        return ModulationOut(*out[:3]), ModulationOut(*out[3:])
+        return tuple(out[:3]), tuple(out[3:])
 
 
 class PhotonBlock(nn.Module):
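
For context, a minimal, self-contained sketch of the tuple-based `Modulation` as it reads after this change. The `nn.Linear(dim, 6 * dim, bias=True)` construction and the demo shapes are assumptions inferred from the `6 * dim` docstring note and the `chunk(6, dim=-1)` call, not copied from the file:

# Minimal sketch of the tuple-returning Modulation shown in this diff.
# The nn.Linear(dim, 6 * dim, bias=True) layer is an assumption inferred from
# the `6 * dim` docstring note and the chunk(6, dim=-1) call above.
import torch
from torch import Tensor, nn


class Modulation(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
        # Project, insert a length-1 token axis, and split into six (B, 1, dim) chunks.
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
        return tuple(out[:3]), tuple(out[3:])


# Each returned tuple is (shift, scale, gate): one for the attention branch, one for the MLP.
mod = Modulation(dim=64)
(attn_shift, attn_scale, attn_gate), (mlp_shift, mlp_scale, mlp_gate) = mod(torch.randn(2, 64))
print(attn_shift.shape)  # torch.Size([2, 1, 64])
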
@@ -301,7 +291,7 @@ class PhotonBlock(nn.Module):
         modulation (`Modulation`):
             Produces scale/shift/gating parameters for modulated layers.
 
-    Methods:
+    Methods:
         attn_forward(img, txt, pe, modulation, attention_mask=None):
             Compute cross-attention between image and text tokens, with optional attention masking.
@@ -312,8 +302,8 @@ class PhotonBlock(nn.Module):
                Text tokens of shape `(B, L_txt, hidden_size)`.
            pe (`torch.Tensor`):
                Rotary positional embeddings to apply to queries and keys.
-            modulation (`ModulationOut`):
-                Scale and shift parameters for modulating image tokens.
+            modulation ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+                Tuple `(shift, scale, gate)` for modulating image tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask of shape `(B, L_txt)` where 0 marks padding.
@@ -371,11 +361,12 @@ class PhotonBlock(nn.Module):
img: Tensor,
txt: Tensor,
pe: Tensor,
modulation: ModulationOut,
modulation: tuple[Tensor, Tensor, Tensor],
attention_mask: None | Tensor = None,
) -> Tensor:
# image tokens proj and norm
img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
shift, scale, _gate = modulation
img_mod = (1 + scale) * self.img_pre_norm(img) + shift
img_qkv = self.img_qkv_proj(img_mod)
# Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
@@ -433,8 +424,9 @@ class PhotonBlock(nn.Module):
         return attn
 
-    def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
-        x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
+    def _ffn_forward(self, x: Tensor, modulation: tuple[Tensor, Tensor, Tensor]) -> Tensor:
+        shift, scale, _gate = modulation
+        x = (1 + scale) * self.post_attention_layernorm(x) + shift
         return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))
 
     def forward(
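
Both `_attn_forward` and `_ffn_forward` now unpack the tuple and apply the same adaptive-norm pattern, `(1 + scale) * norm(x) + shift`, leaving `gate` for the caller's residual connection. A small stand-alone sketch of that pattern (the `modulate` helper and the `LayerNorm` stand-in are illustrative, not part of the diff):

# Sketch of the modulation applied inside _attn_forward and _ffn_forward:
# the normalized activations are rescaled by (1 + scale) and offset by shift,
# while gate is consumed later by the residual connection.
import torch
from torch import Tensor, nn


def modulate(x: Tensor, modulation: tuple[Tensor, Tensor, Tensor], norm: nn.Module) -> Tensor:
    shift, scale, _gate = modulation  # gate is not used here
    return (1 + scale) * norm(x) + shift


# Illustrative shapes: (B, L, dim) tokens, (B, 1, dim) modulation parameters.
norm = nn.LayerNorm(64, elementwise_affine=False)
x = torch.randn(2, 16, 64)
shift, scale, gate = (torch.randn(2, 1, 64) for _ in range(3))
y = modulate(x, (shift, scale, gate), norm)  # same shape as x
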
@@ -470,15 +462,17 @@ class PhotonBlock(nn.Module):
         """
         mod_attn, mod_mlp = self.modulation(vec)
+        attn_shift, attn_scale, attn_gate = mod_attn
+        mlp_shift, mlp_scale, mlp_gate = mod_mlp
 
-        img = img + mod_attn.gate * self._attn_forward(
+        img = img + attn_gate * self._attn_forward(
             img,
             txt,
             pe,
-            mod_attn,
+            (attn_shift, attn_scale, attn_gate),
             attention_mask=attention_mask,
         )
-        img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
+        img = img + mlp_gate * self._ffn_forward(img, (mlp_shift, mlp_scale, mlp_gate))
 
         return img
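
Taken together, `forward` unpacks each `(shift, scale, gate)` tuple once and uses the gates for the two residual branches. A schematic sketch of that control flow, with `attn_branch` and `ffn_branch` as hypothetical stand-ins for `_attn_forward` and `_ffn_forward`:

# Schematic of the gated-residual flow in PhotonBlock.forward after this change.
# attn_branch / ffn_branch are placeholders for the real _attn_forward / _ffn_forward,
# which also take txt, pe, and attention_mask.
import torch
from torch import Tensor


def block_forward(img: Tensor, mod_attn, mod_mlp, attn_branch, ffn_branch) -> Tensor:
    attn_shift, attn_scale, attn_gate = mod_attn
    mlp_shift, mlp_scale, mlp_gate = mod_mlp
    # Gated residual around the modulated attention branch.
    img = img + attn_gate * attn_branch(img, (attn_shift, attn_scale, attn_gate))
    # Gated residual around the modulated feed-forward branch.
    img = img + mlp_gate * ffn_branch(img, (mlp_shift, mlp_scale, mlp_gate))
    return img


# Identity branches just to exercise the call shape.
img = torch.randn(2, 16, 64)
mods = [tuple(torch.zeros(2, 1, 64) for _ in range(3)) for _ in range(2)]
out = block_forward(img, mods[0], mods[1], lambda x, m: x, lambda x, m: x)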