Mirror of https://github.com/huggingface/diffusers.git
UNet Flax with FlaxModelMixin (#502)
* First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
* Remove FlaxUNet2DConfig class.
* ignore_for_config non-config args.
* Implement `FlaxModelMixin`
* Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
* Import `FlaxUNet2DConditionModel` if flax is available.
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Indicate types in flax.struct.dataclass as pointed out by @mishig25
  Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
* Fix typo in transformer block.
* make style
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Style: remove empty line.
* Apply style (after upgrading black from pinned version)
* Remove some commented code and unused imports.
* Add init_weights (not yet in use until #513).
* Trickle down deterministic to blocks.
* Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful.
* Flax UNet docstrings, default props as in PyTorch.
* Fix minor typos in PyTorch docstrings.
* Use FlaxUNet2DConditionOutput as output from UNet.
* make style

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -64,6 +64,7 @@ else:

if is_flax_available():
    from .modeling_flax_utils import FlaxModelMixin
    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .schedulers import (
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
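For context, a minimal sketch (not part of the diff) of what this guarded import enables on the user side; it assumes the flax backend is installed, otherwise the DummyObject stub added at the end of this diff is exported instead:

# Only meaningful when flax is installed.
from diffusers import FlaxModelMixin, FlaxUNet2DConditionModel

# The Flax UNet is built on the new mixin introduced by this PR.
print(issubclass(FlaxUNet2DConditionModel, FlaxModelMixin))  # True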
src/diffusers/models/attention_flax.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        query_states = self.reshape_heads_to_batch_dim(query_proj)
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        # compute attentions
        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
        attention_scores = attention_scores * self.scale
        attention_probs = nn.softmax(attention_scores, axis=2)

        # attend to values
        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return hidden_states


class FlaxBasicTransformerBlock(nn.Module):
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # self attention
        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        # cross attention
        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return hidden_states


class FlaxSpatialTransformer(nn.Module):
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        self.proj_in = nn.Conv(
            inner_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
            for _ in range(self.depth)
        ]

        self.proj_out = nn.Conv(
            inner_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        # import ipdb; ipdb.set_trace()
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)

        hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        hidden_states = hidden_states.reshape(batch, height, width, channels)

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class FlaxGluFeedForward(nn.Module):
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense1(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
        hidden_states = self.dense2(hidden_states)
        return hidden_states
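For orientation, a minimal sketch (not part of the diff) of exercising FlaxAttentionBlock in isolation. The import path and the concrete shapes are assumptions; passing no context makes the block perform self-attention:

import jax
import jax.numpy as jnp

# Assumed import path for illustration.
from diffusers.models.attention_flax import FlaxAttentionBlock

attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)

# (batch, sequence, channels) activations; context=None means self-attention.
hidden_states = jnp.zeros((1, 64, 320), dtype=jnp.float32)

params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)
print(out.shape)  # (1, 64, 320)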
src/diffusers/models/embeddings_flax.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] tensor of positional embeddings.
    """
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
    return emb


class FlaxTimestepEmbedding(nn.Module):
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb


class FlaxTimesteps(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(timesteps, self.dim)
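A quick sketch (not part of the diff) of how the two timestep pieces compose, mirroring what the UNet later does. Shapes, timestep values, and the import path are assumptions:

import jax
import jax.numpy as jnp

# Assumed import path for illustration.
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, get_sinusoidal_embeddings

timesteps = jnp.array([10, 500], dtype=jnp.int32)               # one timestep per batch element

# FlaxTimesteps simply wraps this call in a parameterless module.
t_emb = get_sinusoidal_embeddings(timesteps, embedding_dim=320)  # shape (2, 320)

time_embedding = FlaxTimestepEmbedding(time_embed_dim=1280)
params = time_embedding.init(jax.random.PRNGKey(0), t_emb)
t_emb = time_embedding.apply(params, t_emb)                      # learned MLP projection, shape (2, 1280)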
src/diffusers/models/resnet_flax.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),  # padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        # hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxResnetBlock2D(nn.Module):
    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.swish(temb))
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
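To make the expected layouts concrete, a small sketch (not part of the diff) of driving FlaxResnetBlock2D on its own. The import path and shapes are assumptions; note the channels-last NHWC layout and the timestep embedding being broadcast over height and width:

import jax
import jax.numpy as jnp

# Assumed import path for illustration.
from diffusers.models.resnet_flax import FlaxResnetBlock2D

block = FlaxResnetBlock2D(in_channels=320, out_channels=640)

sample = jnp.zeros((1, 32, 32, 320))   # NHWC activations
temb = jnp.zeros((1, 1280))            # per-sample timestep embedding

params = block.init(jax.random.PRNGKey(0), sample, temb)
out = block.apply(params, sample, temb)  # deterministic=True by default, so dropout needs no rng
print(out.shape)  # (1, 32, 32, 640); channel change triggers the 1x1 shortcut conv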
@@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int`, *optional*): The size of the input sample.
@@ -198,7 +198,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
src/diffusers/models/unet_2d_condition_flax.py (new file, 258 lines)
@@ -0,0 +1,258 @@
from dataclasses import dataclass
from typing import Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .unet_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxCrossAttnUpBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
    FlaxUpBlock2D,
)


@dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
    """
    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: jnp.ndarray


@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
    timestep and returns sample shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the
    library implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int`, *optional*): The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
            "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
            "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: int = 8
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(block_out_channels[0])
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # down
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    attn_num_head_channels=self.attention_head_dim,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
            attn_num_head_channels=self.attention_head_dim,
            dtype=self.dtype,
        )

        # up
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock2D":
                up_block = FlaxCrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    attn_num_head_channels=self.attention_head_dim,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )
            else:
                up_block = FlaxUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )

            up_blocks.append(up_block)
            prev_output_channel = output_channel
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
            timestep (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. up
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)
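A rough end-to-end sketch (not part of the diff) of how the model could be initialized and called. The import path, the prompt length of 77, and the concrete shapes are assumptions; `init_weights` is the helper defined above, and `train=False` keeps dropout disabled:

import jax
import jax.numpy as jnp

# Assumed import path; the class is also re-exported from the package root when flax is installed.
from diffusers.models.unet_2d_condition_flax import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=32)
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 32, 32, 4))                 # NHWC latents, matching init_weights
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 1280))   # (batch, sequence, cross_attention_dim)

out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)  # (1, 32, 32, 4)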
src/diffusers/models/unet_blocks_flax.py (new file, 263 lines)
@@ -0,0 +1,263 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp

from .attention_flax import FlaxSpatialTransformer
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


class FlaxCrossAttnDownBlock2D(nn.Module):
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxSpatialTransformer(
                in_channels=self.out_channels,
                n_heads=self.attn_num_head_channels,
                d_head=self.out_channels // self.attn_num_head_channels,
                depth=1,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_downsample:
            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsample(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxDownBlock2D(nn.Module):
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []

        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsample(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxCrossAttnUpBlock2D(nn.Module):
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxSpatialTransformer(
                in_channels=self.out_channels,
                n_heads=self.attn_num_head_channels,
                d_head=self.out_channels // self.attn_num_head_channels,
                depth=1,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_upsample:
            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsample(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []

        for i in range(self.num_layers):
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsample(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxSpatialTransformer(
                in_channels=self.in_channels,
                n_heads=self.attn_num_head_channels,
                d_head=self.in_channels // self.attn_num_head_channels,
                depth=1,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
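A small sketch (not part of the diff) of driving one down block and inspecting what it returns. The import path, shapes, and the 768-dim text states are illustrative assumptions; `output_states` holds the per-layer residuals that the matching up block later concatenates back in:

import jax
import jax.numpy as jnp

# Assumed import path for illustration.
from diffusers.models.unet_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    num_layers=2,
    attn_num_head_channels=8,
    add_downsample=True,
)

sample = jnp.zeros((1, 32, 32, 320))             # NHWC activations
temb = jnp.zeros((1, 1280))                      # timestep embedding
encoder_hidden_states = jnp.zeros((1, 77, 768))  # text encoder states (dims are assumptions)

params = block.init(jax.random.PRNGKey(0), sample, temb, encoder_hidden_states)
sample, output_states = block.apply(params, sample, temb, encoder_hidden_states)
print(sample.shape, len(output_states))  # (1, 16, 16, 320) 3  -> two layer residuals plus the downsample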
@@ -46,6 +46,13 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
        requires_backends(self, ["flax"])


class FlaxUNet2DConditionModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
    _backends = ["flax"]