From 7c4b38baca6f8c3bbd24c8b458dcd2b507efa129 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 14 Sep 2022 08:20:27 +0200 Subject: [PATCH] Removing `.float()` (`autocast` in fp16 will discard this (I think)). (#495) --- src/diffusers/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 27fae24f71..507ca8632d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -333,7 +333,7 @@ class ResnetBlock2D(nn.Module): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -351,7 +351,7 @@ class ResnetBlock2D(nn.Module): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states)