Mirror of https://github.com/huggingface/diffusers.git
Synced 2026-01-27 17:22:53 +03:00
feat: add Dropout to Flax UNet (#3894)
* feat: add Dropout to Flax UNet
* feat: add @compact decorator
* fix: drop nn.compact
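The pattern repeated throughout the diff is standard Flax: instantiate an nn.Dropout in setup() and gate it with the deterministic flag that the __call__ methods already accept. A minimal sketch of that pattern, using a hypothetical DropBlock module (illustrative only, not the diffusers code):

import flax.linen as nn


class DropBlock(nn.Module):
    # Hypothetical module mirroring the shape of the change:
    # a projection created in setup(), then Dropout gated by `deterministic`.
    features: int = 32
    dropout: float = 0.0

    def setup(self):
        self.proj = nn.Dense(self.features)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        # Identity when deterministic=True (inference); random masking otherwise.
        return self.dropout_layer(hidden_states, deterministic=deterministic)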
@@ -152,6 +152,7 @@ class FlaxAttention(nn.Module):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ class FlaxAttention(nn.Module):

         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ class FlaxTransformer2DModel(nn.Module):
                 dtype=self.dtype,
             )

+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
         hidden_states = self.proj_out(hidden_states)

         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ class FlaxFeedForward(nn.Module):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states

@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
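Note: the new dropout_layer calls follow Flax's usual contract: dropout is a no-op when deterministic=True and requires an RNG stream named "dropout" otherwise. A minimal standalone check of that contract (plain flax.linen, not diffusers code):

import jax
import jax.numpy as jnp
import flax.linen as nn

drop = nn.Dropout(rate=0.5)
x = jnp.ones((2, 4))

# Dropout has no parameters, so an empty variables dict is enough.
y_eval = drop.apply({}, x, deterministic=True)  # identity
y_train = drop.apply(
    {}, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(0)}
)  # zeroes ~50% of entries and rescales the rest by 1 / (1 - rate)

So when training with a non-zero dropout rate, callers are expected to pass deterministic=False and supply a "dropout" RNG when applying the model.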