diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 1a40a82971..c5809bc2c0 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -340,7 +340,7 @@ class Modulation(nn.Module): nn.init.constant_(self.lin.weight, 0) nn.init.constant_(self.lin.bias, 0) - def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: + def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) return tuple(out[:3]), tuple(out[3:])