1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add UNet 1d for RL model for planning + colab (#105)

* re-add RL model code

* match model forward api

* add register_to_config, pass training tests

* fix tests, update forward outputs

* remove unused code, some comments

* add to docs

* remove extra embedding code

* unify time embedding

* remove conv1d output sequential

* remove sequential from conv1dblock

* style and deleting duplicated code

* clean files

* remove unused variables

* clean variables

* add 1d resnet block structure for downsample

* rename as unet1d

* fix renaming

* rename files

* add get_block(...) api

* unify args for model1d like model2d

* minor cleaning

* fix docs

* improve 1d resnet blocks

* fix tests, remove permuts

* fix style

* add output activation

* rename flax blocks file

* Add Value Function and corresponding example script to Diffuser implementation (#884)

* valuefunction code

* start example scripts

* missing imports

* bug fixes and placeholder example script

* add value function scheduler

* load value function from hub and get best actions in example

* very close to working example

* larger batch size for planning

* more tests

* merge unet1d changes

* wandb for debugging, use newer models

* success!

* turns out we just need more diffusion steps

* run on modal

* merge and code cleanup

* use same api for rl model

* fix variance type

* wrong normalization function

* add tests

* style

* style and quality

* edits based on comments

* style and quality

* remove unused var

* hack unet1d into a value function

* add pipeline

* fix arg order

* add pipeline to core library

* community pipeline

* fix couple shape bugs

* style

* Apply suggestions from code review

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* update post merge of scripts

* add mdiblock / outblock architecture

* Pipeline cleanup (#947)

* valuefunction code

* start example scripts

* missing imports

* bug fixes and placeholder example script

* add value function scheduler

* load value function from hub and get best actions in example

* very close to working example

* larger batch size for planning

* more tests

* merge unet1d changes

* wandb for debugging, use newer models

* success!

* turns out we just need more diffusion steps

* run on modal

* merge and code cleanup

* use same api for rl model

* fix variance type

* wrong normalization function

* add tests

* style

* style and quality

* edits based on comments

* style and quality

* remove unused var

* hack unet1d into a value function

* add pipeline

* fix arg order

* add pipeline to core library

* community pipeline

* fix couple shape bugs

* style

* Apply suggestions from code review

* clean up comments

* convert older script to using pipeline and add readme

* rename scripts

* style, update tests

* delete unet rl model file

* remove imports in src

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* Update src/diffusers/models/unet_1d_blocks.py

* Update tests/test_models_unet.py

* RL Cleanup v2 (#965)

* valuefunction code

* start example scripts

* missing imports

* bug fixes and placeholder example script

* add value function scheduler

* load value function from hub and get best actions in example

* very close to working example

* larger batch size for planning

* more tests

* merge unet1d changes

* wandb for debugging, use newer models

* success!

* turns out we just need more diffusion steps

* run on modal

* merge and code cleanup

* use same api for rl model

* fix variance type

* wrong normalization function

* add tests

* style

* style and quality

* edits based on comments

* style and quality

* remove unused var

* hack unet1d into a value function

* add pipeline

* fix arg order

* add pipeline to core library

* community pipeline

* fix couple shape bugs

* style

* Apply suggestions from code review

* clean up comments

* convert older script to using pipeline and add readme

* rename scripts

* style, update tests

* delete unet rl model file

* remove imports in src

* add specific vf block and update tests

* style

* Update tests/test_models_unet.py

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* fix quality in tests

* fix quality style, split test file

* fix checks / tests

* make timesteps closer to main

* unify block API

* unify forward api

* delete lines in examples

* style

* examples style

* all tests pass

* make style

* make dance_diff test pass

* Refactoring RL PR (#1200)

* init file changes

* add import utils

* finish cleaning files, imports

* remove import flags

* clean examples

* fix imports, tests for merge

* update readmes

* hotfix for tests

* quality

* fix some tests

* change defaults

* more mps test fixes

* unet1d defaults

* do not default import experimental

* defaults for tests

* fix tests

* fix-copies

* fix

* changes per Patrik's comments (#1285)

* changes per Patrik's comments

* update conversion script

* fix renaming

* skip more mps tests

* last test fix

* Update examples/rl/README.md

Co-authored-by: Ben Glickenhaus <benglickenhaus@gmail.com>
This commit is contained in:
Nathan Lambert
2022-11-14 13:48:48 -08:00
committed by GitHub
parent a8d0977769
commit 7c5fef81e0
18 changed files with 1176 additions and 65 deletions

4
.gitignore vendored
View File

@@ -163,4 +163,6 @@ tags
*.lock
# DS_Store (MacOS)
.DS_Store
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4

View File

@@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet2DModel
[[autodoc]] UNet2DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
## UNet1DModel
[[autodoc]] UNet1DModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput

View File

@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon.
## Community

19
examples/rl/README.md Normal file
View File

@@ -0,0 +1,19 @@
# Overview
These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers.
There are four scripts,
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.
You will need some RL specific requirements to run the examples:
```
pip install -f https://download.pytorch.org/whl/torch_stable.html \
free-mujoco-py \
einops \
gym==0.24.1 \
protobuf==3.20.1 \
git+https://github.com/rail-berkeley/d4rl.git \
mediapy \
Pillow==9.0.0
```

View File

@@ -0,0 +1,57 @@
import d4rl # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline
config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=0,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)
if __name__ == "__main__":
env_name = "hopper-medium-v2"
env = gym.make(env_name)
pipeline = ValueGuidedRLPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
env=env,
)
env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# Call the policy
denorm_actions = pipeline(obs, planning_horizon=32)
# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())
obs = next_observation
except KeyboardInterrupt:
pass
print(f"Total reward: {total_reward}")

View File

@@ -0,0 +1,57 @@
import d4rl # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline
config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=2,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)
if __name__ == "__main__":
env_name = "hopper-medium-v2"
env = gym.make(env_name)
pipeline = ValueGuidedRLPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
env=env,
)
env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# call the policy
denorm_actions = pipeline(obs, planning_horizon=32)
# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())
obs = next_observation
except KeyboardInterrupt:
pass
print(f"Total reward: {total_reward}")

View File

@@ -0,0 +1,100 @@
import json
import os
import torch
from diffusers import UNet1DModel
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
def unet(hor):
if hor == 128:
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
block_out_channels = (32, 128, 256)
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
elif hor == 32:
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
block_out_channels = (32, 64, 128, 256)
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
state_dict = model.state_dict()
config = dict(
down_block_types=down_block_types,
block_out_channels=block_out_channels,
up_block_types=up_block_types,
layers_per_block=1,
use_timestep_embedding=True,
out_block_type="OutConv1DBlock",
norm_num_groups=8,
downsample_each_block=False,
in_channels=14,
out_channels=14,
extra_in_channels=0,
time_embedding_type="positional",
flip_sin_to_cos=False,
freq_shift=1,
sample_size=65536,
mid_block_type="MidResTemporalBlock1D",
act_fn="mish",
)
hf_value_function = UNet1DModel(**config)
print(f"length of state dict: {len(state_dict.keys())}")
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
for k, v in mapping.items():
state_dict[v] = state_dict.pop(k)
hf_value_function.load_state_dict(state_dict)
torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
json.dump(config, f)
def value_function():
config = dict(
in_channels=14,
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
up_block_types=(),
out_block_type="ValueFunction",
mid_block_type="ValueFunctionMidBlock1D",
block_out_channels=(32, 64, 128, 256),
layers_per_block=1,
downsample_each_block=True,
sample_size=65536,
out_channels=14,
extra_in_channels=0,
time_embedding_type="positional",
use_timestep_embedding=True,
flip_sin_to_cos=False,
freq_shift=1,
norm_num_groups=8,
act_fn="mish",
)
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
state_dict = model
hf_value_function = UNet1DModel(**config)
print(f"length of state dict: {len(state_dict.keys())}")
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
for k, v in mapping.items():
state_dict[v] = state_dict.pop(k)
hf_value_function.load_state_dict(state_dict)
torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
json.dump(config, f)
if __name__ == "__main__":
unet(32)
# unet(128)
value_function()

View File

@@ -0,0 +1,5 @@
# 🧨 Diffusers Experimental
We are adding experimental code to support novel applications and usages of the Diffusers library.
Currently, the following experiments are supported:
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.

View File

@@ -0,0 +1 @@
from .rl import ValueGuidedRLPipeline

View File

@@ -0,0 +1 @@
from .value_guided_sampling import ValueGuidedRLPipeline

View File

@@ -0,0 +1,129 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import tqdm
from ...models.unet_1d import UNet1DModel
from ...pipeline_utils import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler
class ValueGuidedRLPipeline(DiffusionPipeline):
def __init__(
self,
value_function: UNet1DModel,
unet: UNet1DModel,
scheduler: DDPMScheduler,
env,
):
super().__init__()
self.value_function = value_function
self.unet = unet
self.scheduler = scheduler
self.env = env
self.data = env.get_dataset()
self.means = dict()
for key in self.data.keys():
try:
self.means[key] = self.data[key].mean()
except:
pass
self.stds = dict()
for key in self.data.keys():
try:
self.stds[key] = self.data[key].std()
except:
pass
self.state_dim = env.observation_space.shape[0]
self.action_dim = env.action_space.shape[0]
def normalize(self, x_in, key):
return (x_in - self.means[key]) / self.stds[key]
def de_normalize(self, x_in, key):
return x_in * self.stds[key] + self.means[key]
def to_torch(self, x_in):
if type(x_in) is dict:
return {k: self.to_torch(v) for k, v in x_in.items()}
elif torch.is_tensor(x_in):
return x_in.to(self.unet.device)
return torch.tensor(x_in, device=self.unet.device)
def reset_x0(self, x_in, cond, act_dim):
for key, val in cond.items():
x_in[:, key, act_dim:] = val.clone()
return x_in
def run_diffusion(self, x, conditions, n_guide_steps, scale):
batch_size = x.shape[0]
y = None
for i in tqdm.tqdm(self.scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
grad = torch.autograd.grad([y.sum()], [x])[0]
posterior_variance = self.scheduler._get_variance(i)
model_std = torch.exp(0.5 * posterior_variance)
grad = model_std * grad
grad[timesteps < 2] = 0
x = x.detach()
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
# normalize the observations and create batch dimension
obs = self.normalize(obs, "observations")
obs = obs[None].repeat(batch_size, axis=0)
conditions = {0: self.to_torch(obs)}
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)
# run the diffusion process
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
# sort output trajectories by value
sorted_idx = y.argsort(0, descending=True).squeeze()
sorted_values = x[sorted_idx]
actions = sorted_values[:, :, : self.action_dim]
actions = actions.detach().cpu().numpy()
denorm_actions = self.de_normalize(actions, key="actions")
# select the action with the highest value
if y is not None:
selected_index = 0
else:
# if we didn't run value guiding, select a random action
selected_index = np.random.randint(0, batch_size)
denorm_actions = denorm_actions[selected_index, 0]
return denorm_actions

View File

@@ -62,14 +62,21 @@ def get_timestep_embedding(
class TimestepEmbedding(nn.Module):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
elif act_fn == "mish":
self.act = nn.Mish()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
def forward(self, sample):
sample = self.linear_1(sample)

View File

@@ -5,6 +5,75 @@ import torch.nn as nn
import torch.nn.functional as F
class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv_transpose:
out_channels:
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
out_channels:
padding:
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
use_conv_transpose:
out_channels:
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
out_channels:
padding:
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
# unet_rl.py
def rearrange_dims(tensor):
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
return tensor[:, :, None, :]
elif len(tensor.shape) == 4:
return tensor[:, :, 0, :]
else:
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
self.mish = nn.Mish()
def forward(self, x):
x = self.conv1d(x)
x = rearrange_dims(x)
x = self.group_norm(x)
x = rearrange_dims(x)
x = self.mish(x)
return x
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
self.time_emb_act = nn.Mish()
self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
Args:
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
"""
t = self.time_emb_act(t)
t = self.time_emb(t)
out = self.conv_in(x) + rearrange_dims(t)
out = self.conv_out(out)
return out + self.residual_conv(x)
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given

View File

@@ -1,3 +1,17 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@dataclass
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
downsample_each_block (`int`, *optional*, defaults to False:
experimental feature for using a UNet without upsampling.
"""
@register_to_config
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
mid_block_type: str = "UNetMidBlock1D",
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size
# time
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
if i == 0:
input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type=mid_block_type,
mid_channels=block_out_channels[-1],
mid_block_type,
in_channels=block_out_channels[-1],
out_channels=None,
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
# Totally fine to add another layer with a if statement - no need for nn.Identity here
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward(
self,
@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if len(timestep.shape) == 0:
timestep = timestep[None]
timestep_embed = self.time_proj(timestep)[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# 2. down
down_block_res_samples = ()
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 3. mid
sample = self.mid_block(sample)
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_samples)
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)

View File

@@ -17,6 +17,256 @@ import torch
import torch.nn.functional as F
from torch import nn
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
num_layers=1,
conv_shortcut=False,
temb_channels=32,
groups=32,
groups_out=None,
non_linearity=None,
time_embedding_norm="default",
output_scale_factor=1.0,
add_downsample=True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.add_downsample = add_downsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
self.nonlinearity = None
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states, temb=None):
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
num_layers=1,
temb_channels=32,
groups=32,
groups_out=None,
non_linearity=None,
time_embedding_norm="default",
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.time_embedding_norm = time_embedding_norm
self.add_upsample = add_upsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
self.nonlinearity = None
self.upsample = None
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, embed_dim):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.embed_dim = embed_dim
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x, temb=None):
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
x = self.down2(x)
return x
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dim,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_downsample = add_downsample
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
self.resnets = nn.ModuleList(resnets)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
else:
self.nonlinearity = None
self.upsample = None
if add_upsample:
self.upsample = Downsample1D(out_channels, use_conv=True)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states, temb):
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
return hidden_states
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
if act_fn == "silu":
self.final_conv1d_act = nn.SiLU()
if act_fn == "mish":
self.final_conv1d_act = nn.Mish()
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states, temb=None):
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_act(hidden_states)
hidden_states = self.final_conv1d_2(hidden_states)
return hidden_states
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim, embed_dim):
super().__init__()
self.final_block = nn.ModuleList(
[
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
nn.Mish(),
nn.Linear(fc_dim // 2, 1),
]
)
def forward(self, hidden_states, temb):
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
hidden_states = layer(hidden_states)
return hidden_states
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states):
def forward(self, hidden_states, temb=None):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
return output
def get_down_block(down_block_type, out_channels, in_channels):
if down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, in_channels, out_channels):
if up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
if mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None):
super().__init__()
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states):
def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, res_hidden_states_tuple):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
hidden_states = resnet(hidden_states)
return hidden_states
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
elif up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
elif mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim)
return None

View File

@@ -204,6 +204,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = torch.log(torch.clamp(variance, min=1e-20))
variance = torch.exp(0.5 * variance)
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
@@ -301,7 +302,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
else:
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
pred_prev_sample = pred_prev_sample + variance

View File

@@ -18,13 +18,120 @@ import unittest
import torch
from diffusers import UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils import floats_tensor, slow, torch_device
from ..test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class UnetModel1DTests(unittest.TestCase):
class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (4, 14, 16)
def test_ema_training(self):
pass
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_save_pretrained(self):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64, 128, 256),
"in_channels": 14,
"out_channels": 14,
"time_embedding_type": "positional",
"use_timestep_embedding": True,
"flip_sin_to_cos": False,
"freq_shift": 1.0,
"out_block_type": "OutConv1DBlock",
"mid_block_type": "MidResTemporalBlock1D",
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_features = model.in_channels
seq_len = 16
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
) # match original, we can update values and remove
time_step = torch.full((num_features,), 0)
with torch.no_grad():
output = model(noise, time_step).sample.permute(0, 2, 1)
output_slice = output[0, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
assert (output_sum - 224.0896).abs() < 4e-2
assert (output_max - 0.0607).abs() < 4e-4
class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (4, 14, 1)
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_save_pretrained(self):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
# UNetRL is a value-function is different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_ema_training(self):
pass
def test_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
"up_block_types": [],
"out_block_type": "ValueFunction",
"mid_block_type": "ValueFunctionMidBlock1D",
"block_out_channels": [32, 64, 128, 256],
"layers_per_block": 1,
"downsample_each_block": True,
"use_timestep_embedding": True,
"freq_shift": 1.0,
"flip_sin_to_cos": False,
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
value_function.to(torch_device)
image = value_function(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_features = value_function.in_channels
seq_len = 14
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
) # match original, we can update values and remove
time_step = torch.full((num_features,), 0)
with torch.no_grad():
output = value_function(noise, time_step).sample
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass

View File

@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
sample_rate=16_000,
in_channels=2,
out_channels=2,
flip_sin_to_cos=True,
use_timestep_embedding=False,
time_embedding_type="fourier",
mid_block_type="UNetMidBlock1D",
down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
)