diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index c9c07f191f..0f24e50369 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -340,7 +340,7 @@ if __name__ == "__main__": default=DEFAULT_RESOLUTION, help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.", ) - + parser.add_argument( "--shift", type=float, diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index b5a89d642d..b77e2f8d6f 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -227,6 +227,7 @@ class QKNorm(torch.nn.Module): k = self.key_norm(k) return q.to(v), k.to(v) + class Modulation(nn.Module): r""" Modulation network that generates scale, shift, and gating parameters. @@ -339,8 +340,6 @@ class PhotonBlock(nn.Module): self.modulation = Modulation(hidden_size) - - def forward( self, img: Tensor,