1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove restriction to run conv_norm in fp32

This commit is contained in:
Nouamane Tazi
2022-09-21 13:38:41 +00:00
parent 4e67675b89
commit cec592890c
2 changed files with 6 additions and 3 deletions

View File

@@ -332,7 +332,8 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
# hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@@ -350,7 +351,8 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
# hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)

View File

@@ -261,7 +261,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
# sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype)
sample = self.conv_norm_out(sample).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)