1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Replace dropout_prob by dropout in vae (#595)

replace `dropout_prob` by `dropout` in `vae`
This commit is contained in:
Younes Belkada
2022-09-21 11:43:28 +02:00
committed by GitHub
parent 8685699392
commit 3fc8ef7297

View File

@@ -89,7 +89,7 @@ class FlaxDownsample2D(nn.Module):
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
dropout: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
@@ -106,7 +106,7 @@ class FlaxResnetBlock2D(nn.Module):
)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.dropout = nn.Dropout(self.dropout_prob)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
@@ -135,7 +135,7 @@ class FlaxResnetBlock2D(nn.Module):
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.dropout_layer(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
@@ -217,7 +217,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -251,7 +251,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -284,7 +284,7 @@ class FlaxUNetMidBlock2D(nn.Module):
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
]
@@ -300,7 +300,7 @@ class FlaxUNetMidBlock2D(nn.Module):
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)