1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

feat: add Dropout to Flax UNet (#3894)

* feat: add Dropout to Flax UNet

* feat: add @compact decorator

* fix: drop nn.compact
This commit is contained in:
Saurav Maheshkar
2023-07-07 15:08:16 +05:30
committed by GitHub
parent 8d8b4311b9
commit 03d829d59e

View File

@@ -152,6 +152,7 @@ class FlaxAttention(nn.Module):
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
self.dropout_layer = nn.Dropout(rate=self.dropout)
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ class FlaxAttention(nn.Module):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ class FlaxBasicTransformerBlock(nn.Module):
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, context, deterministic=True):
# self attention
@@ -280,7 +282,7 @@ class FlaxBasicTransformerBlock(nn.Module):
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
return hidden_states
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ class FlaxTransformer2DModel(nn.Module):
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
@@ -378,7 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ class FlaxFeedForward(nn.Module):
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.net_0(hidden_states)
hidden_states = self.net_0(hidden_states, deterministic=deterministic)
hidden_states = self.net_2(hidden_states)
return hidden_states
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
def setup(self):
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
return hidden_linear * nn.gelu(hidden_gelu)
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)