diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 785a4b9135..36e9dd611e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -332,7 +332,8 @@ class ResnetBlock2D(nn.Module): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -350,7 +351,8 @@ class ResnetBlock2D(nn.Module): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 42b54657d2..e584520e94 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -261,7 +261,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): # 6. post-process # make sure hidden states is in float32 # when running in half-precision - sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + # sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample)