mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Mochi-1] ensuring to compute the fourier features in FP32 in Mochi encoder (#10031)
compute fourier features in FP32.
This commit is contained in:
@@ -437,7 +437,8 @@ class FourierFeatures(nn.Module):
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `FourierFeatures` class."""
|
||||
|
||||
original_dtype = inputs.dtype
|
||||
inputs = inputs.to(torch.float32)
|
||||
num_channels = inputs.shape[1]
|
||||
num_freqs = (self.stop - self.start) // self.step
|
||||
|
||||
@@ -450,7 +451,7 @@ class FourierFeatures(nn.Module):
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
|
||||
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
|
||||
|
||||
|
||||
class MochiEncoder3D(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user