From c96bfa5c80eca798d555a79a491043c311d0f608 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 29 Nov 2024 14:15:00 +0530 Subject: [PATCH] [Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (#10031) compute fourier features in FP32. --- src/diffusers/models/autoencoders/autoencoder_kl_mochi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 0eabf3a26d..920b0b62fe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -437,7 +437,8 @@ class FourierFeatures(nn.Module): def forward(self, inputs: torch.Tensor) -> torch.Tensor: r"""Forward method of the `FourierFeatures` class.""" - + original_dtype = inputs.dtype + inputs = inputs.to(torch.float32) num_channels = inputs.shape[1] num_freqs = (self.stop - self.start) // self.step @@ -450,7 +451,7 @@ class FourierFeatures(nn.Module): # Scale channels by frequency. h = w * h - return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1) + return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype) class MochiEncoder3D(nn.Module):