diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 49ff7d6bfa..6b0089d5c2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -335,8 +335,6 @@ class ResnetBlock2D(nn.Module): def forward(self, x, temb): hidden_states = x - # make sure hidden states is in float32 - # when running in half-precision hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) @@ -353,8 +351,6 @@ class ResnetBlock2D(nn.Module): temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] hidden_states = hidden_states + temb - # make sure hidden states is in float32 - # when running in half-precision hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 04453e0645..dd9e2e570b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -313,8 +313,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample)