1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Removing .float() (autocast in fp16 will discard this (I think)). (#495)

This commit is contained in:
Nicolas Patry
2022-09-14 08:20:27 +02:00
committed by GitHub
parent ab7a78e8f1
commit 7c4b38baca

View File

@@ -333,7 +333,7 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@@ -351,7 +351,7 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)