
UNet Flax with FlaxModelMixin (#502)

* First UNet Flax modeling blocks.

Mimic the structure of the PyTorch files.
The model classes themselves need work, depending on what we do about
configuration and initialization.

* Remove FlaxUNet2DConfig class.

* Mark non-config args with `ignore_for_config`.

* Implement `FlaxModelMixin`

* Use new mixins for Flax UNet.

For some reason the configuration is not correctly applied; the
signature of the `__init__` method does not contain all the parameters
by the time it's inspected in `extract_init_dict`.

* Import `FlaxUNet2DConditionModel` if flax is available.

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Indicate types in flax.struct.dataclass as pointed out by @mishig25

Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>

* Fix typo in transformer block.

* make style

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Style: remove empty line.

* Apply style (after upgrading black from pinned version)

* Remove some commented code and unused imports.

* Add init_weights (not yet in use until #513).

* Trickle down deterministic to blocks.

* Rename q, k, v according to the latest PyTorch version.

Note that weights were exported with the old names, so we need to be
careful.

* Flax UNet docstrings, default props as in PyTorch.

* Fix minor typos in PyTorch docstrings.

* Use FlaxUNet2DConditionOutput as output from UNet.

* make style

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Authored by Pedro Cuenca on 2022-09-15 18:07:15 +02:00, committed by GitHub
parent fb5468a6aa
commit d8b0e4f433
8 changed files with 878 additions and 2 deletions

View File

@@ -64,6 +64,7 @@ else:
if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,

View File

@@ -0,0 +1,180 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax.numpy as jnp
class FlaxAttentionBlock(nn.Module):
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context
query_proj = self.query(hidden_states)
key_proj = self.key(context)
value_proj = self.value(context)
query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states
class FlaxBasicTransformerBlock(nn.Module):
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual
# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
return hidden_states
class FlaxSpatialTransformer(nn.Module):
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class FlaxGluFeedForward(nn.Module):
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states
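A minimal self-attention smoke test for the block above, assuming `FlaxAttentionBlock` as defined in this file is in scope; the sizes are illustrative rather than the ones used in the Stable Diffusion UNet.

import jax
import jax.numpy as jnp

# inner_dim = heads * dim_head = 64; proj_attn maps back to query_dim.
attn = FlaxAttentionBlock(query_dim=64, heads=2, dim_head=32)
hidden_states = jnp.ones((1, 16, 64))  # (batch, sequence, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # context=None -> self-attention
assert out.shape == (1, 16, 64)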

View File

@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import flax.linen as nn
import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb
class FlaxTimestepEmbedding(nn.Module):
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb
class FlaxTimesteps(nn.Module):
dim: int = 32
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)
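A shape sketch for the timestep embedding path above; the sizes are illustrative (the UNet later feeds `block_out_channels[0]` sinusoids into an embedding of width `block_out_channels[0] * 4`).

import jax
import jax.numpy as jnp

timesteps = jnp.array([0.0, 10.0, 999.0])
sinusoid = get_sinusoidal_embeddings(timesteps, embedding_dim=320)  # (3, 320)

time_embed = FlaxTimestepEmbedding(time_embed_dim=1280)
params = time_embed.init(jax.random.PRNGKey(0), sinusoid)
temb = time_embed.apply(params, sinusoid)
assert temb.shape == (3, 1280)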

View File

@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxDownsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)), # padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states):
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.dropout = nn.Dropout(self.dropout_prob)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, temb, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)
temb = self.time_emb_proj(nn.swish(temb))
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
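A minimal NHWC sketch of the resnet block above; the channel and embedding sizes are illustrative, and the time embedding is projected to `out_channels` inside the block.

import jax
import jax.numpy as jnp

block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
hidden_states = jnp.ones((1, 8, 8, 32))  # Flax convolutions use NHWC layout
temb = jnp.ones((1, 128))
params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
out = block.apply(params, hidden_states, temb)  # deterministic=True by default
assert out.shape == (1, 8, 8, 64)  # 32 != 64 triggers the 1x1 nin shortcut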

View File

@@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
@@ -198,7 +198,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"""r
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-        timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

View File

@@ -0,0 +1,258 @@
from dataclasses import dataclass
from typing import Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .unet_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
FlaxUpBlock2D,
)
@dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: jnp.ndarray
@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
timestep and returns sample shaped output.
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
"""
sample_size: int = 32
in_channels: int = 4
out_channels: int = 4
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: int = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensors
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
def setup(self):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv(
block_out_channels[0],
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
# time
self.time_proj = FlaxTimesteps(block_out_channels[0])
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
# down
down_blocks = []
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(self.down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if down_block_type == "CrossAttnDownBlock2D":
down_block = FlaxCrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=self.attention_head_dim,
add_downsample=not is_final_block,
dtype=self.dtype,
)
else:
down_block = FlaxDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
add_downsample=not is_final_block,
dtype=self.dtype,
)
down_blocks.append(down_block)
self.down_blocks = down_blocks
# mid
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=self.attention_head_dim,
dtype=self.dtype,
)
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
if up_block_type == "CrossAttnUpBlock2D":
up_block = FlaxCrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=self.attention_head_dim,
add_upsample=not is_final_block,
dropout=self.dropout,
dtype=self.dtype,
)
else:
up_block = FlaxUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
add_upsample=not is_final_block,
dropout=self.dropout,
dtype=self.dtype,
)
up_blocks.append(up_block)
prev_output_channel = output_channel
self.up_blocks = up_blocks
# out
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(
self,
sample,
timesteps,
encoder_hidden_states,
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
"""r
Args:
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
else:
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
# 5. up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
if isinstance(up_block, FlaxCrossAttnUpBlock2D):
sample = up_block(
sample,
temb=t_emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
deterministic=not train,
)
else:
sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = nn.silu(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return FlaxUNet2DConditionOutput(sample=sample)
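A hedged end-to-end sketch of the model above, using the `init_weights` helper and the channels-last layout it expects. The default configuration is a full-size UNet, so random initialization is slow and memory hungry; the sequence length of 77 below is just an illustrative choice.

import jax
import jax.numpy as jnp

unet = FlaxUNet2DConditionModel(sample_size=32)
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 32, 32, 4))                # NHWC, as in init_weights
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 1280))  # (batch, seq_len, cross_attention_dim)
out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
assert out.sample.shape == (1, 32, 32, 4)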

View File

@@ -0,0 +1,263 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxSpatialTransformer
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
attentions = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
self.resnets = resnets
self.attentions = attentions
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxDownBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, deterministic=True):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxCrossAttnUpBlock2D(nn.Module):
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
attentions = []
for i in range(self.num_layers):
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
self.resnets = resnets
self.attentions = attentions
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
return hidden_states
class FlaxUpBlock2D(nn.Module):
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
return hidden_states
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
]
attentions = []
for _ in range(self.num_layers):
attn_block = FlaxSpatialTransformer(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
self.attentions = attentions
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
return hidden_states
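A small shape sketch for the plain down block above (illustrative sizes, NHWC layout); the cross-attention variants additionally take `encoder_hidden_states`.

import jax
import jax.numpy as jnp

down = FlaxDownBlock2D(in_channels=32, out_channels=64, num_layers=2)
x = jnp.ones((1, 16, 16, 32))
temb = jnp.ones((1, 128))
params = down.init(jax.random.PRNGKey(0), x, temb)
x_out, skips = down.apply(params, x, temb)
assert x_out.shape == (1, 8, 8, 64)  # halved by the strided downsample conv
assert len(skips) == 3               # one skip per resnet plus the downsample output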

View File

@@ -46,6 +46,13 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]