[Dance Diffusion] Add dance diffusion (#803)
* start
* add more logic
* Update src/diffusers/models/unet_2d_condition_flax.py
* match weights
* up
* make model work
* making class more general, fixing missed file rename
* small fix
* make new conversion work
* up
* finalize conversion
* up
* first batch of variable renamings
* remove c and c_prev var names
* add mid and out block structure
* add pipeline
* up
* finish conversion
* finish
* upload
* more fixes
* Apply suggestions from code review
* add attr
* up
* uP
* up
* finish tests
* finish
* uP
* finish
* fix test
* up
* naming consistency in tests
* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* remove hardcoded 16
* Remove bogus
* fix some stuff
* finish
* improve logging
* docs
* upload

Co-authored-by: Nathan Lambert <nol@berkeley.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent 0b42b074b4
commit 88fa6b7d68
docs/source/_toctree.yml
@@ -92,5 +92,7 @@
         title: "Stable Diffusion"
       - local: api/pipelines/stochastic_karras_ve
         title: "Stochastic Karras VE"
+      - local: api/pipelines/dance_diffusion
+        title: "Dance Diffusion"
     title: "Pipelines"
   title: "API"
docs/source/api/models.mdx
@@ -22,6 +22,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

 ## UNet2DOutput
 [[autodoc]] models.unet_2d.UNet2DOutput

+## UNet1DModel
+[[autodoc]] UNet1DModel
+
 ## UNet2DModel
 [[autodoc]] UNet2DModel
docs/source/api/pipelines/dance_diffusion.mdx (new file, 33 lines)
@@ -0,0 +1,33 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Dance Diffusion

## Overview

[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) by Zach Evans.

Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai.
For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page.

The original codebase of this implementation can be found [here](https://github.com/Harmonai-org/sample-generator).

## Available Pipelines:

| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_dance_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py) | *Unconditional Audio Generation* | - |

## DanceDiffusionPipeline
[[autodoc]] DanceDiffusionPipeline
    - __call__
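A minimal usage sketch of the pipeline added in this commit (the Hub id `harmonai/maestro-150k` is an assumption for illustration; any checkpoint produced by the conversion script below works the same way):

```python
import scipy.io.wavfile
import torch

from diffusers import DanceDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: "harmonai/maestro-150k" is an assumed checkpoint id, used here for illustration
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k").to(device)

generator = torch.Generator(device=device).manual_seed(42)
output = pipe(num_inference_steps=100, generator=generator)

# output.audios has shape (batch_size, num_channels, sample_size)
audio = output.audios[0]
# scipy expects (samples, channels), hence the transpose
scipy.io.wavfile.write("dance_diffusion.wav", pipe.unet.sample_rate, audio.T)
```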
docs/source/api/schedulers.mdx
@@ -95,6 +95,10 @@ Original paper can be found [here](https://arxiv.org/abs/2011.13456).

 [[autodoc]] ScoreSdeVeScheduler

+#### improved pseudo numerical methods for diffusion models (iPNDM)
+
+Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).
+
 #### variance preserving stochastic differential equation (SDE) scheduler

 Original paper can be found [here](https://arxiv.org/abs/2011.13456).
scripts/convert_dance_diffusion_to_diffusers.py (new executable file, 339 lines)
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
import argparse
import math
import os
from copy import deepcopy

import torch
from torch import nn

from audio_diffusion.models import DiffusionAttnUnet1D
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusion import sampling


MODELS_MAP = {
    "gwf-440k": {
        "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-small-190k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-large-580k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
        "sample_rate": 48000,
        "sample_size": 131072,
    },
    "maestro-uncond-150k": {
        "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "unlocked-uncond-250k": {
        "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "honk-140k": {
        "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
}


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2


def get_crash_schedule(t):
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma**2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)


class Object(object):
    pass


class DiffusionUncond(nn.Module):
    def __init__(self, global_args):
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)


def download(model_name):
    url = MODELS_MAP[model_name]["url"]
    os.system(f"wget {url} ./")

    return f"./{model_name}.ckpt"


DOWN_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
}
UP_NUM_TO_LAYER = {
    "8": "resnets.0",
    "9": "attentions.0",
    "10": "resnets.1",
    "11": "attentions.1",
    "12": "resnets.2",
    "13": "attentions.2",
}
MID_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
    "8": "resnets.3",
    "9": "attentions.3",
    "10": "resnets.4",
    "11": "attentions.4",
    "12": "resnets.5",
    "13": "attentions.5",
}
DEPTH_0_TO_LAYER = {
    "0": "resnets.0",
    "1": "resnets.1",
    "2": "resnets.2",
    "4": "resnets.0",
    "5": "resnets.1",
    "6": "resnets.2",
}

RES_CONV_MAP = {
    "skip": "conv_skip",
    "main.0": "conv_1",
    "main.1": "group_norm_1",
    "main.3": "conv_2",
    "main.4": "group_norm_2",
}

ATTN_MAP = {
    "norm": "group_norm",
    "qkv_proj": ["query", "key", "value"],
    "out_proj": ["proj_attn"],
}


def convert_resconv_naming(name):
    if name.startswith("skip"):
        return name.replace("skip", RES_CONV_MAP["skip"])

    # name has to be of format main.{digit}
    if not name.startswith("main."):
        raise ValueError(f"ResConvBlock error with {name}")

    return name.replace(name[:6], RES_CONV_MAP[name[:6]])


def convert_attn_naming(name):
    for key, value in ATTN_MAP.items():
        if name.startswith(key) and not isinstance(value, list):
            return name.replace(key, value)
        elif name.startswith(key):
            return [name.replace(key, v) for v in value]
    raise ValueError(f"Attn error with {name}")


def rename(input_string, max_depth=13):
    string = input_string

    if string.split(".")[0] == "timestep_embed":
        return string.replace("timestep_embed", "time_proj")

    depth = 0
    if string.startswith("net.3."):
        depth += 1
        string = string[6:]
    elif string.startswith("net."):
        string = string[4:]

    while string.startswith("main.7."):
        depth += 1
        string = string[7:]

    if string.startswith("main."):
        string = string[5:]

    # mid block
    if string[:2].isdigit():
        layer_num = string[:2]
        string_left = string[2:]
    else:
        layer_num = string[0]
        string_left = string[1:]

    if depth == max_depth:
        new_layer = MID_NUM_TO_LAYER[layer_num]
        prefix = "mid_block"
    elif depth > 0 and int(layer_num) < 7:
        new_layer = DOWN_NUM_TO_LAYER[layer_num]
        prefix = f"down_blocks.{depth}"
    elif depth > 0 and int(layer_num) > 7:
        new_layer = UP_NUM_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - depth - 1}"
    elif depth == 0:
        new_layer = DEPTH_0_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"

    if not string_left.startswith("."):
        raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")

    string_left = string_left[1:]

    if "resnets" in new_layer:
        string_left = convert_resconv_naming(string_left)
    elif "attentions" in new_layer:
        new_string_left = convert_attn_naming(string_left)
        string_left = new_string_left

    if not isinstance(string_left, list):
        new_string = prefix + "." + new_layer + "." + string_left
    else:
        new_string = [prefix + "." + new_layer + "." + s for s in string_left]
    return new_string


def rename_orig_weights(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.endswith("kernel"):
            # up- and downsample layers don't have trainable weights
            continue

        new_k = rename(k)

        # check if we need to transform from Conv => Linear for attention
        if isinstance(new_k, list):
            new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
        else:
            new_state_dict[new_k] = v

    return new_state_dict


def transform_conv_attns(new_state_dict, new_k, v):
    if len(new_k) == 1:
        if len(v.shape) == 3:
            # weight
            new_state_dict[new_k[0]] = v[:, :, 0]
        else:
            # bias
            new_state_dict[new_k[0]] = v
    else:
        # qkv matrices
        tripled_shape = v.shape[0]
        single_shape = tripled_shape // 3
        for i in range(3):
            if len(v.shape) == 3:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
            else:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
    return new_state_dict


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert (
            model_name == args.model_path
        ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    for key, value in renamed_state_dict.items():
        assert (
            diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
        ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        if key == "time_proj.weight":
            value = value.squeeze()

        diffusers_state_dict[key] = value

    diffusers_model.load_state_dict(diffusers_state_dict)

    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    generator = torch.manual_seed(33)
    audio = pipe(num_inference_steps=steps, generator=generator).audios

    generated = sampling.iplms_sample(orig_model, noise, step_list, {})
    generated = generated.clamp(-1, 1)

    diff_sum = (generated - audio).abs().sum()
    diff_max = (generated - audio).abs().max()

    if args.save:
        pipe.save_pretrained(args.checkpoint_path)

    print("Diff sum", diff_sum)
    print("Diff max", diff_max)

    assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"

    print(f"Conversion for {model_name} successful!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
    )
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
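For reference, a typical invocation of the script above passes one of the official names from `MODELS_MAP` (which triggers the `wget` download) plus an output directory, e.g. `python scripts/convert_dance_diffusion_to_diffusers.py --model_path gwf-440k --checkpoint_path ./gwf-440k-diffusers`. Note that `--save` is declared with `type=bool`, so any non-empty string argument (including `"False"`) parses as `True`; leave it at its default to save the converted pipeline.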
src/diffusers/__init__.py
@@ -18,7 +18,7 @@ from .utils import logging

 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
@@ -29,10 +29,19 @@ if is_torch_available():
         get_scheduler,
     )
     from .pipeline_utils import DiffusionPipeline
-    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+    from .pipelines import (
+        DanceDiffusionPipeline,
+        DDIMPipeline,
+        DDPMPipeline,
+        KarrasVePipeline,
+        LDMPipeline,
+        PNDMPipeline,
+        ScoreSdeVePipeline,
+    )
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
+        IPNDMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
         SchedulerMixin,
src/diffusers/models/__init__.py
@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available


 if is_torch_available():
+    from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .vae import AutoencoderKL, VQModel
src/diffusers/models/embeddings.py
@@ -101,17 +101,28 @@ class Timesteps(nn.Module):
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""

-    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+    def __init__(
+        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+    ):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+        self.log = log
+        self.flip_sin_to_cos = flip_sin_to_cos

-        # to delete later
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+        if set_W_to_weight:
+            # to delete later
+            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

-        self.weight = self.W
+            self.weight = self.W

     def forward(self, x):
-        x = torch.log(x)
+        if self.log:
+            x = torch.log(x)

         x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
-        out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+        if self.flip_sin_to_cos:
+            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+        else:
+            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
         return out
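A quick sketch of the new flags in action, using the same settings `UNet1DModel` passes below (`set_W_to_weight=False`, `log=False`); the input values are illustrative:

```python
import torch

from diffusers.models.embeddings import GaussianFourierProjection

# Dance Diffusion style: continuous timesteps in [0, 1], no log-transform
proj = GaussianFourierProjection(embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=True)

t = torch.linspace(1, 0, 5)  # a batch of 5 continuous timesteps
emb = proj(t)
print(emb.shape)  # torch.Size([5, 16]) -- cos and sin halves concatenated
```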
src/diffusers/models/unet_1d.py (new file, 172 lines)
@@ -0,0 +1,172 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            Hidden states output. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip sin to cos for fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        mid_block_type: str = "UNetMidBlock1D",
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        block_out_channels: Tuple[int] = (32, 32, 64),
    ):
        super().__init__()

        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            down_block = get_down_block(
                down_block_type,
                in_channels=input_channel,
                out_channels=output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type=mid_block_type,
            mid_channels=block_out_channels[-1],
            in_channels=block_out_channels[-1],
            out_channels=None,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels

            up_block = get_up_block(
                up_block_type,
                in_channels=prev_output_channel,
                out_channels=output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # TODO(PVP, Nathan) placeholder for RL application to be merged shortly
        # Totally fine to add another layer with a if statement - no need for nn.Identity here

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        if len(timestep.shape) == 0:
            timestep = timestep[None]

        timestep_embed = self.time_proj(timestep)[..., None]
        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        sample = self.mid_block(sample)

        # 4. up
        for i, upsample_block in enumerate(self.up_blocks):
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_samples)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
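A minimal shape check for the model above (a sketch: `extra_in_channels=16` matches the 16-channel Fourier timestep embedding that `DownBlock1DNoSkip` concatenates onto the input; the sample length just has to be divisible by the total downsampling factor of 2**3 = 8):

```python
import torch

from diffusers import UNet1DModel

model = UNet1DModel(sample_size=65536, sample_rate=48000, extra_in_channels=16)

noisy_audio = torch.randn(1, 2, 2048)  # (batch, channels, length), length divisible by 8
timestep = torch.tensor(0.5)           # continuous timestep for the fourier embedding

with torch.no_grad():
    denoised = model(noisy_audio, timestep).sample

print(denoised.shape)  # torch.Size([1, 2, 2048]) -- same shape as the input
```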
src/diffusers/models/unet_1d_blocks.py (new file, 384 lines)
@@ -0,0 +1,384 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn.functional as F
from torch import nn


_kernels = {
    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
    "lanczos3": [
        0.003689131001010537,
        0.015056144446134567,
        -0.03399861603975296,
        -0.066637322306633,
        0.13550527393817902,
        0.44638532400131226,
        0.44638532400131226,
        0.13550527393817902,
        -0.066637322306633,
        -0.03399861603975296,
        0.015056144446134567,
        0.003689131001010537,
    ],
}


class Downsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states):
        hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv1d(hidden_states, weight, stride=2)


class Upsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states):
        hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)


class SelfAttention1d(nn.Module):
    def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
        super().__init__()
        self.channels = in_channels
        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
        self.num_heads = n_head

        self.query = nn.Linear(self.channels, self.channels)
        self.key = nn.Linear(self.channels, self.channels)
        self.value = nn.Linear(self.channels, self.channels)

        self.proj_attn = nn.Linear(self.channels, self.channels, 1)

        self.dropout = nn.Dropout(dropout_rate, inplace=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel_dim, seq = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)

        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))

        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # compute attention output
        hidden_states = torch.matmul(attention_probs, value_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.dropout(hidden_states)

        output = hidden_states + residual

        return output


class ResConvBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)

        if not self.is_last:
            self.group_norm_2 = nn.GroupNorm(1, out_channels)
            self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states

        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.group_norm_1(hidden_states)
        hidden_states = self.gelu_1(hidden_states)
        hidden_states = self.conv_2(hidden_states)

        if not self.is_last:
            hidden_states = self.group_norm_2(hidden_states)
            hidden_states = self.gelu_2(hidden_states)

        output = hidden_states + residual
        return output


def get_down_block(down_block_type, out_channels, in_channels):
    if down_block_type == "DownBlock1D":
        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(up_block_type, in_channels, out_channels):
    if up_block_type == "UpBlock1D":
        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{up_block_type} does not exist.")


def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
    if mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")


class UNetMidBlock1D(nn.Module):
    def __init__(self, mid_channels, in_channels, out_channels=None):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # there is always at least one resnet
        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]
        self.up = Upsample1d(kernel="cubic")

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states):
        hidden_states = self.down(hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class AttnDownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1DNoSkip(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = torch.cat([hidden_states, temb], dim=1)
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class AttnUpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1DNoSkip(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, res_hidden_states_tuple):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states
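A small sanity sketch of the fixed-kernel resamplers above: `Downsample1d` halves the sequence length with a smoothing kernel, and `Upsample1d` doubles it again via a transposed convolution with the same (rescaled) kernel.

```python
import torch

from diffusers.models.unet_1d_blocks import Downsample1d, Upsample1d

down = Downsample1d(kernel="cubic")
up = Upsample1d(kernel="cubic")

x = torch.randn(1, 32, 128)  # (batch, channels, length)
y = down(x)                  # length halved:  (1, 32, 64)
z = up(y)                    # length doubled: (1, 32, 128)
print(y.shape, z.shape)
```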
src/diffusers/models/unet_2d.py
@@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
-from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


 @dataclass
src/diffusers/models/unet_2d_condition.py
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
-from .unet_blocks import (
+from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
     CrossAttnUpBlock2D,
     DownBlock2D,
src/diffusers/models/unet_2d_condition_flax.py
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
 from ..modeling_flax_utils import FlaxModelMixin
 from ..utils import BaseOutput
 from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
-from .unet_blocks_flax import (
+from .unet_2d_blocks_flax import (
     FlaxCrossAttnDownBlock2D,
     FlaxCrossAttnUpBlock2D,
     FlaxDownBlock2D,
src/diffusers/models/vae.py
@@ -21,7 +21,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
-from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


 @dataclass
src/diffusers/pipeline_utils.py
@@ -93,6 +93,20 @@ class ImagePipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]


+@dataclass
+class AudioPipelineOutput(BaseOutput):
+    """
+    Output class for audio pipelines.
+
+    Args:
+        audios (`np.ndarray`)
+            List of denoised samples of shape `(batch_size, num_channels, sample_size)`. NumPy array representing the
+            denoised audio samples of the diffusion pipeline.
+    """
+
+    audios: np.ndarray
+
+
 class DiffusionPipeline(ConfigMixin):
     r"""
     Base class for all models.
src/diffusers/pipelines/__init__.py
@@ -2,6 +2,7 @@ from ..utils import is_flax_available, is_onnx_available, is_torch_available, is


 if is_torch_available():
+    from .dance_diffusion import DanceDiffusionPipeline
     from .ddim import DDIMPipeline
     from .ddpm import DDPMPipeline
     from .latent_diffusion_uncond import LDMPipeline
src/diffusers/pipelines/dance_diffusion/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# flake8: noqa
from .pipeline_dance_diffusion import DanceDiffusionPipeline
src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, Union

import torch

from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from ...utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DanceDiffusionPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of
            [`IPNDMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 100,
        generator: Optional[torch.Generator] = None,
        sample_length_in_s: Optional[float] = None,
        return_dict: bool = True,
    ) -> Union[AudioPipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of audio samples to generate.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
                the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipeline_utils.AudioPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated audio samples.
        """

        if sample_length_in_s is None:
            sample_length_in_s = self.unet.sample_size / self.unet.sample_rate

        sample_size = sample_length_in_s * self.unet.sample_rate

        down_scale_factor = 2 ** len(self.unet.up_blocks)
        if sample_size < 3 * down_scale_factor:
            raise ValueError(
                f"{sample_length_in_s} is too small. Make sure it's bigger or equal to"
                f" {3 * down_scale_factor / self.unet.sample_rate}."
            )

        original_sample_size = int(sample_size)
        if sample_size % down_scale_factor != 0:
            sample_size = ((sample_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
            logger.info(
                f"{sample_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
                f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
                " process."
            )
        sample_size = int(sample_size)

        audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, device=audio.device)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(audio, t).sample

            # 2. compute previous audio sample: x_t -> x_t-1
            audio = self.scheduler.step(model_output, t, audio).prev_sample

        audio = audio.clamp(-1, 1).cpu().numpy()

        audio = audio[:, :, :original_sample_size]

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
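To make the length handling above concrete: with three up blocks, `down_scale_factor` is `2 ** 3 = 8`, so a requested length that is not a multiple of 8 is padded up before denoising and trimmed back afterwards. A pure-Python sketch of that arithmetic:

```python
down_scale_factor = 2 ** 3  # three up blocks, as in the default UNet1DModel config

requested = 1601  # requested length in samples; not a multiple of 8
padded = (requested // down_scale_factor + 1) * down_scale_factor
print(padded)  # 1608 -- the model denoises at this length
# afterwards the pipeline trims back: audio = audio[:, :, :1601]
```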
src/diffusers/schedulers/__init__.py
@@ -19,6 +19,7 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
 if is_torch_available():
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
+    from .scheduling_ipndm import IPNDMScheduler
     from .scheduling_karras_ve import KarrasVeScheduler
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
src/diffusers/schedulers/scheduling_ipndm.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Improved pseudo numerical methods for diffusion models (iPNDM), ported from @crowsonkb's amazing k-diffusion
    [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
    """

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000):
        # set `betas`, `alphas`, `timesteps`
        self.set_timesteps(num_train_timesteps)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.ets = []

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
        self.num_inference_steps = num_inference_steps
        steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        steps = torch.cat([steps, torch.tensor([0.0])])

        # the sine ("crash") schedule: beta acts as the noise scale, alpha as the clean-signal scale
        self.betas = torch.sin(steps * math.pi / 2) ** 2
        self.alphas = (1.0 - self.betas**2) ** 0.5

        timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
        self.timesteps = timesteps.to(device)

        self.ets = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. It requires only one model forward
        pass per step and reuses up to three previous model outputs to approximate the solution.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep_index = (self.timesteps == timestep).nonzero().item()
        prev_timestep_index = timestep_index + 1

        ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        self.ets.append(ets)

        if len(self.ets) == 1:
            ets = self.ets[-1]
        elif len(self.ets) == 2:
            ets = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]

        next_alpha = self.alphas[prev_timestep_index]
        next_sigma = self.betas[prev_timestep_index]

        pred = (sample - sigma * ets) / max(alpha, 1e-8)
        prev_sample = next_alpha * pred + ets * next_sigma

        return prev_sample

    def __len__(self):
        return self.config.num_train_timesteps
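A short sketch of the scheduler above in isolation: after `set_timesteps`, `betas` and `alphas` follow the sine ("crash") schedule, and `step` accumulates past model outputs in `self.ets` for the fourth-order multi-step update (a zero tensor stands in for the UNet prediction here):

```python
import torch

from diffusers import IPNDMScheduler

scheduler = IPNDMScheduler()
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 2, 256)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for the UNet prediction
    sample = scheduler.step(model_output, t, sample).prev_sample

print(sample.shape, len(scheduler.ets))  # torch.Size([1, 2, 256]) 50
```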
@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class UNet1DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class UNet2DConditionModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -122,6 +137,21 @@ class DiffusionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class DanceDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDIMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -242,6 +272,21 @@ class DDPMScheduler(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class IPNDMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class KarrasVeScheduler(metaclass=DummyObject):
    _backends = ["torch"]

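The three hunks above register lazy placeholders (presumably in `utils/dummy_pt_objects.py`, going by the surrounding classes) so that `import diffusers` still succeeds when PyTorch is not installed; the failure is deferred to first use, with an actionable message. A self-contained sketch of the pattern's behavior, offered as an illustration rather than the library's own `DummyObject`:

```python
class DummyObject(type):
    # Class-level attribute access (e.g. from_pretrained) on a placeholder
    # raises a helpful error naming the missing backend.
    def __getattr__(cls, key):
        raise ImportError(f"{cls.__name__} requires the torch backend: pip install torch")


class IPNDMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        raise ImportError(f"{type(self).__name__} requires the torch backend: pip install torch")


# IPNDMScheduler()                -> ImportError at call time, not at import time
# IPNDMScheduler.from_pretrained  -> ImportError via the metaclass __getattr__
```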
101 tests/pipelines/dance_diffusion/test_dance_diffusion.py Normal file
@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


class PipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_unet(self):
        torch.manual_seed(0)
        model = UNet1DModel(
            block_out_channels=(32, 32, 64),
            extra_in_channels=16,
            sample_size=512,
            sample_rate=16_000,
            in_channels=2,
            out_channels=2,
            down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
            up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
        )
        return model

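Note how the `up_block_types` mirror the `down_block_types` in reverse, with the `NoSkip` variants at the outermost ends. As a sanity check, the model maps a `(batch, channels, samples)` waveform to an output of the same shape; a quick smoke run reusing the exact config above (our illustration, not part of the test file):

```python
import torch

from diffusers import UNet1DModel

torch.manual_seed(0)
unet = UNet1DModel(
    block_out_channels=(32, 32, 64),
    extra_in_channels=16,
    sample_size=512,
    sample_rate=16_000,
    in_channels=2,
    out_channels=2,
    down_block_types=["DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"],
    up_block_types=["AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"],
)
x = torch.randn(1, 2, 512)  # stereo waveform, sample_size samples
with torch.no_grad():
    out = unet(x, torch.tensor([1])).sample
print(out.shape)  # expected: torch.Size([1, 2, 512])
```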
    def test_dance_diffusion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        scheduler = IPNDMScheduler()

        pipe = DanceDiffusionPipeline(unet=self.dummy_unet, scheduler=scheduler)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(generator=generator, num_inference_steps=4)
        audio = output.audios

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(generator=generator, num_inference_steps=4, return_dict=False)
        audio_from_tuple = output[0]

        audio_slice = audio[0, -3:, -3:]
        audio_from_tuple_slice = audio_from_tuple[0, -3:, -3:]

        assert audio.shape == (1, 2, self.dummy_unet.sample_size)
        expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(audio_from_tuple_slice.flatten() - expected_slice).max() < 1e-2


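For readers who want to hear the output rather than assert on slices, a minimal usage sketch outside the test harness (the checkpoint and step count follow the integration test below; saving to WAV via `scipy` is our own choice, and we assume the unet exposes its configured `sample_rate` as an attribute, as it does for `sample_size`):

```python
import torch
from scipy.io import wavfile

from diffusers import DanceDiffusionPipeline

pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
audio = pipe(generator=generator, num_inference_steps=100).audios[0]  # (channels, samples)

# scipy expects (samples, channels); float samples should already lie in [-1, 1]
wavfile.write("maestro.wav", rate=pipe.unet.sample_rate, data=audio.T)
```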
@slow
@require_torch_gpu
class PipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dance_diffusion(self):
        device = torch_device

        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)
        audio = output.audios

        audio_slice = audio[0, -3:, -3:]

        assert audio.shape == (1, 2, pipe.unet.sample_size)
        expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
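The `sample_length_in_s=4.096` above is less arbitrary than it looks: assuming the checkpoint's 16 kHz sample rate implied by these numbers, 16000 × 4.096 = 65536 samples, exactly the `sample_size = 65536` probed in the `UNet1DModel` test below, so the generated waveform matches `pipe.unet.sample_size` without padding or cropping.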
45 tests/test_models_unet_1d.py Normal file
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import UNet1DModel
from diffusers.utils import slow, torch_device


torch.backends.cuda.matmul.allow_tf32 = False


class UnetModel1DTests(unittest.TestCase):
    @slow
    def test_unet_1d_maestro(self):
        model_id = "harmonai/maestro-150k"
        model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
        model.to(torch_device)

        sample_size = 65536
        noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device)
        timestep = torch.tensor([1]).to(torch_device)

        with torch.no_grad():
            output = model(noise, timestep).sample

        output_sum = output.abs().sum()
        output_max = output.abs().max()

        assert (output_sum - 224.0896).abs() < 4e-2
        assert (output_max - 0.0607).abs() < 4e-4
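The sine probe gives the slow test a deterministic input without depending on RNG state or device seeding. A quick shape check of the same construction (illustration only):

```python
import torch

sample_size = 65536
noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1))
print(noise.shape)  # torch.Size([1, 2, 65536]) -> (batch, channels, samples)
```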
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False


-class UnetModelTests(ModelTesterMixin, unittest.TestCase):
+class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

    @property
@@ -19,7 +19,14 @@ from typing import Dict, List, Tuple
import numpy as np
import torch

-from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
+from diffusers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    IPNDMScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    ScoreSdeVeScheduler,
+)


torch.backends.cuda.matmul.allow_tf32 = False
@@ -203,6 +210,10 @@ class SchedulerCommonTest(unittest.TestCase):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", 50)

        timestep = 0
        if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
            timestep = 1

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
@@ -215,14 +226,14 @@
        elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
            kwargs["num_inference_steps"] = num_inference_steps

-        outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
+        outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            scheduler.set_timesteps(num_inference_steps)
        elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
            kwargs["num_inference_steps"] = num_inference_steps

-        outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
+        outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)

        recursive_check(outputs_tuple, outputs_dict)

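Replacing the hard-coded `0` with `timestep` is what the `timestep = 1` special case above exists for: as shown earlier, `IPNDMScheduler.step` locates the step via `(self.timesteps == timestep).nonzero().item()`, so calling it with a value that never appears in the schedule would fail the lookup outright rather than quietly using the first step.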
@@ -901,3 +912,157 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):

        assert abs(result_sum.item() - 1006.388) < 1e-2
        assert abs(result_mean.item() - 1.31) < 1e-3


class IPNDMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (IPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {"num_train_timesteps": 1000}
        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.ets = dummy_past_residuals[:]

            if time_step is None:
                time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

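Note the ordering inside the loop: the dummy `ets` history is injected only after `set_timesteps`, which (judging by the "must be after setting timesteps" comments repeated throughout this file) also resets the scheduler's stored residual history; assigning first would leave an empty `ets` and a first-order step. The step/assert pair then runs twice so the original and the saved-and-reloaded scheduler are also compared after both have appended a fresh residual to their histories.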
    def test_from_pretrained_save_pretrained(self):
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.ets = dummy_past_residuals[:]

            if time_step is None:
                time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
            scheduler.ets = dummy_past_residuals[:]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps, time_step=None)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=None)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 2540529) < 10
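The expected mean of 2540529 is striking next to the single-digit regression values of the other schedulers above; the IPNDM loop shown here applies no clipping or rescaling, so sample magnitudes grow over the double pass, and the assertion pins the current numerical behavior (to within ±10) rather than making any claim about output quality.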