1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (#10031)

compute fourier features in FP32.
This commit is contained in:
Sayak Paul
2024-11-29 14:15:00 +05:30
committed by GitHub
parent 6b288ec44d
commit c96bfa5c80

View File

@@ -437,7 +437,8 @@ class FourierFeatures(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `FourierFeatures` class."""
original_dtype = inputs.dtype
inputs = inputs.to(torch.float32)
num_channels = inputs.shape[1]
num_freqs = (self.stop - self.start) // self.step
@@ -450,7 +451,7 @@ class FourierFeatures(nn.Module):
# Scale channels by frequency.
h = w * h
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
class MochiEncoder3D(nn.Module):