remove modulation dataclass
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -228,29 +227,20 @@ class QKNorm(torch.nn.Module):
         k = self.key_norm(k)
         return q.to(v), k.to(v)


-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
 class Modulation(nn.Module):
     r"""
     Modulation network that generates scale, shift, and gating parameters.

     Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
-    two `ModulationOut` objects.
+    two tuples `(shift, scale, gate)`.

     Parameters:
         dim (`int`):
             Dimensionality of the input vector. The output will have `6 * dim` features internally.

     Returns:
-        (`ModulationOut`, `ModulationOut`):
-            A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift,
-            gate).
+        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+            Two tuples `(shift, scale, gate)`.
     """

     def __init__(self, dim: int):
@@ -259,9 +249,9 @@ class Modulation(nn.Module):
         nn.init.constant_(self.lin.weight, 0)
         nn.init.constant_(self.lin.bias, 0)

-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
+    def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
-        return ModulationOut(*out[:3]), ModulationOut(*out[3:])
+        return tuple(out[:3]), tuple(out[3:])


 class PhotonBlock(nn.Module):
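To make the new return contract concrete, here is a minimal, self-contained sketch of the tuple-returning `Modulation` (illustration only, not part of the commit; the toy `dim` and batch size are made up):

    import torch
    from torch import Tensor, nn


    class Modulation(nn.Module):
        # Stripped-down re-creation of the module shown in the hunks above, for illustration.
        def __init__(self, dim: int):
            super().__init__()
            self.lin = nn.Linear(dim, 6 * dim)
            nn.init.constant_(self.lin.weight, 0)
            nn.init.constant_(self.lin.bias, 0)

        def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]:
            # Project to 6 * dim features, add a token axis, and split into six (B, 1, dim) chunks.
            out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
            return tuple(out[:3]), tuple(out[3:])


    mod = Modulation(dim=8)
    mod_attn, mod_mlp = mod(torch.randn(2, 8))  # two (shift, scale, gate) tuples of (2, 1, 8) tensors
    shift, scale, gate = mod_attn               # plain tuple unpacking replaces the attribute access
                                                # (mod_attn.shift, ...) that ModulationOut used to provide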
@@ -301,7 +291,7 @@ class PhotonBlock(nn.Module):
         modulation (`Modulation`):
             Produces scale/shift/gating parameters for modulated layers.

     Methods:
         attn_forward(img, txt, pe, modulation, attention_mask=None):
             Compute cross-attention between image and text tokens, with optional attention masking.

@@ -312,8 +302,8 @@ class PhotonBlock(nn.Module):
                 Text tokens of shape `(B, L_txt, hidden_size)`.
             pe (`torch.Tensor`):
                 Rotary positional embeddings to apply to queries and keys.
-            modulation (`ModulationOut`):
-                Scale and shift parameters for modulating image tokens.
+            modulation ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+                Tuple `(shift, scale, gate)` for modulating image tokens.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask of shape `(B, L_txt)` where 0 marks padding.

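As a side note on the `attention_mask` argument: a `(B, L_txt)` boolean mask where 0 marks padding is typically broadcast and handed to scaled-dot-product attention. The sketch below is a generic PyTorch illustration of that pattern with made-up shapes, not the actual PhotonBlock wiring:

    import torch
    import torch.nn.functional as F

    B, H, L_img, L_txt, D = 2, 4, 16, 7, 32
    q = torch.randn(B, H, L_img, D)   # image-token queries
    k = torch.randn(B, H, L_txt, D)   # text-token keys
    v = torch.randn(B, H, L_txt, D)   # text-token values

    attention_mask = torch.ones(B, L_txt, dtype=torch.bool)
    attention_mask[:, -2:] = False    # last two text tokens are padding (0 / False)

    # Broadcast (B, L_txt) -> (B, 1, 1, L_txt); True positions are attended, False are masked out.
    attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask[:, None, None, :])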
@@ -371,11 +361,12 @@ class PhotonBlock(nn.Module):
         img: Tensor,
         txt: Tensor,
         pe: Tensor,
-        modulation: ModulationOut,
+        modulation: tuple[Tensor, Tensor, Tensor],
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
-        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
+        shift, scale, _gate = modulation
+        img_mod = (1 + scale) * self.img_pre_norm(img) + shift

         img_qkv = self.img_qkv_proj(img_mod)
         # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
@@ -433,8 +424,9 @@ class PhotonBlock(nn.Module):

         return attn

-    def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
-        x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
+    def _ffn_forward(self, x: Tensor, modulation: tuple[Tensor, Tensor, Tensor]) -> Tensor:
+        shift, scale, _gate = modulation
+        x = (1 + scale) * self.post_attention_layernorm(x) + shift
         return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))

     def forward(
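Both `_attn_forward` and `_ffn_forward` now unpack the tuple and apply the same scale-and-shift to their normalized input. A standalone sketch of that pattern, with hypothetical shapes and a plain `LayerNorm` standing in for the block's norm layers:

    import torch
    from torch import nn

    B, L, D = 2, 16, 8
    x = torch.randn(B, L, D)
    norm = nn.LayerNorm(D, elementwise_affine=False)

    # (shift, scale, gate) as produced by Modulation, each broadcast over the token axis.
    shift, scale, _gate = (torch.randn(B, 1, D) for _ in range(3))

    # Modulated pre-norm: scale and shift the normalized tokens before the attention/MLP branch.
    x_mod = (1 + scale) * norm(x) + shift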
@@ -470,15 +462,17 @@ class PhotonBlock(nn.Module):
         """

         mod_attn, mod_mlp = self.modulation(vec)
+        attn_shift, attn_scale, attn_gate = mod_attn
+        mlp_shift, mlp_scale, mlp_gate = mod_mlp

-        img = img + mod_attn.gate * self._attn_forward(
+        img = img + attn_gate * self._attn_forward(
             img,
             txt,
             pe,
-            mod_attn,
+            (attn_shift, attn_scale, attn_gate),
             attention_mask=attention_mask,
         )
-        img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
+        img = img + mlp_gate * self._ffn_forward(img, (mlp_shift, mlp_scale, mlp_gate))
         return img

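Putting it together, `forward` keeps the gated residual form around both branches. The sketch below illustrates the update with a placeholder `branch` function (not the real `_attn_forward`/`_ffn_forward`); since `Modulation` is zero-initialized, all gates start at zero and the block initially leaves `img` unchanged:

    import torch
    from torch import nn

    B, L, D = 2, 16, 8
    img = torch.randn(B, L, D)
    norm = nn.LayerNorm(D, elementwise_affine=False)
    proj = nn.Linear(D, D)  # placeholder for the attention / MLP computation

    def branch(x, shift, scale):
        # modulated pre-norm followed by a stand-in transformation
        return proj((1 + scale) * norm(x) + shift)

    # Zero tensors mimic the zero-initialized Modulation output at the start of training.
    attn_shift, attn_scale, attn_gate = (torch.zeros(B, 1, D) for _ in range(3))
    mlp_shift, mlp_scale, mlp_gate = (torch.zeros(B, 1, D) for _ in range(3))

    img = img + attn_gate * branch(img, attn_shift, attn_scale)  # gated residual around attention
    img = img + mlp_gate * branch(img, mlp_shift, mlp_scale)     # gated residual around the MLP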