mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Replace dropout_prob by dropout in vae (#595)
replace `dropout_prob` by `dropout` in `vae`
This commit is contained in:
@@ -89,7 +89,7 @@ class FlaxDownsample2D(nn.Module):
|
||||
class FlaxResnetBlock2D(nn.Module):
|
||||
in_channels: int
|
||||
out_channels: int = None
|
||||
dropout_prob: float = 0.0
|
||||
dropout: float = 0.0
|
||||
use_nin_shortcut: bool = None
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@@ -106,7 +106,7 @@ class FlaxResnetBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.dropout = nn.Dropout(self.dropout_prob)
|
||||
self.dropout_layer = nn.Dropout(self.dropout)
|
||||
self.conv2 = nn.Conv(
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
@@ -135,7 +135,7 @@ class FlaxResnetBlock2D(nn.Module):
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = nn.swish(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic)
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
@@ -217,7 +217,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout_prob=self.dropout,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
@@ -251,7 +251,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout_prob=self.dropout,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
@@ -284,7 +284,7 @@ class FlaxUNetMidBlock2D(nn.Module):
|
||||
FlaxResnetBlock2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout_prob=self.dropout,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
]
|
||||
@@ -300,7 +300,7 @@ class FlaxUNetMidBlock2D(nn.Module):
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout_prob=self.dropout,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
Reference in New Issue
Block a user