
Merge remote-tracking branch 'origin/main'

anton-l
2022-06-15 14:37:04 +02:00
17 changed files with 221 additions and 326 deletions


@@ -53,17 +53,23 @@ The class provides functionality to compute previous image according to alpha, b
## Quickstart
### Installation
**Note**: If you want to run PyTorch on GPU on a CUDA-compatible machine, please make sure to install the corresponding `torch` version from the
[official website](https://pytorch.org/).
```
git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install -e .
```
### 1. `diffusers` as a central modular diffusion and sampler library
### 1. `diffusers` as a toolbox for schedulers and models.
`diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can easily use only the parts of the library they need for their own use cases.
It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
Both models and schedulers should be loadable from and saveable to the Hub.
For more examples see [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) and [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)
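To make the mix-and-match idea concrete, here is a minimal sketch of loading a pretrained model from the Hub, pairing it with a scheduler, and saving the model to a local folder. The `save_pretrained` method and the `timesteps` keyword are assumptions based on the usual Hugging Face API and the scheduler code further down, not confirmed signatures.
```python
from diffusers import DDPMScheduler, UNetModel

# load a pretrained UNet from the Hub and pair it with a freshly configured scheduler
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church")
noise_scheduler = DDPMScheduler(timesteps=1000)  # `timesteps` kwarg is an assumption

# save the model locally and reload it from the local copy
unet.save_pretrained("./ddpm-lsun-church-local")  # `save_pretrained` is assumed to exist on ModelMixin
unet = UNetModel.from_pretrained("./ddpm-lsun-church-local")
```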
#### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**
```python
@@ -82,29 +88,29 @@ unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = torch.randn(
    (1, unet.in_channels, unet.resolution, unet.resolution),
    generator=generator,
)
image = image.to(torch_device)

# 3. Denoise
num_prediction_steps = len(noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
    # predict noise residual
    with torch.no_grad():
        residual = unet(image, t)

    # predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t)

    # optionally sample variance
    variance = 0
    if t > 0:
        noise = torch.randn(image.shape, generator=generator).to(image.device)
        variance = noise_scheduler.get_variance(t).sqrt() * noise

    # set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 5. process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
@@ -171,25 +177,35 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
### 2. `diffusers` as a collection of most important Diffusion systems (GLIDE, Dalle, ...)
The `models` directory in the repository hosts the complete code necessary for running a diffusion system as well as for training it. A `DiffusionPipeline` class makes it easy to run the diffusion model for inference:
### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
#### **Example image generation with DDPM**
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
#### **Example image generation with PNDM**
```python
from diffusers import DiffusionPipeline
from diffusers import PNDM, UNetModel, PNDMScheduler
import PIL.Image
import numpy as np
import torch
model_id = "fusing/ddim-celeba-hq"
model = UNetModel.from_pretrained(model_id)
scheduler = PNDMScheduler()
# load model and scheduler
ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")
ddpm = PNDM(unet=model, noise_scheduler=scheduler)
# run pipeline in inference (sample random noise and denoise)
image = ddpm()
with torch.no_grad():
image = ddpm()
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = (image_processed + 1.0) / 2
image_processed = torch.clamp(image_processed, 0.0, 1.0)
image_processed = image_processed * 255
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])


@@ -144,9 +144,11 @@ if __name__ == "__main__":
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
args = parser.parse_args()
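As a side note, the `fp16`/`bf16` choice above is typically translated into a `torch` dtype for autocasting. The snippet below is an illustrative sketch of that mapping under those assumptions; it is not code from this training script.
```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"])
args = parser.parse_args([])  # empty list so the sketch runs without real CLI arguments

# map the user's choice to the dtype that would be used for autocast/mixed precision
dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
print(f"selected autocast dtype: {dtype}")
```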


@@ -87,7 +87,6 @@ _deps = [
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"torchvision",
]
# this is a lookup table with items like:
@@ -172,7 +171,6 @@ install_requires = [
deps["regex"],
deps["requests"],
deps["torch"],
deps["torchvision"],
deps["Pillow"],
]


@@ -6,10 +6,10 @@ __version__ = "0.0.3"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


@@ -226,7 +226,7 @@ class ConfigMixin:
return json.loads(text)
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> Dict[str, Any]:


@@ -13,5 +13,4 @@ deps = {
"regex": "regex!=2019.12.17",
"requests": "requests",
"torch": "torch>=1.4",
"torchvision": "torchvision",
}


@@ -17,6 +17,6 @@
# limitations under the License.
from .unet import UNetModel
from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_grad_tts import UNetGradTTSModel


@@ -26,7 +26,6 @@ from torch.optim import Adam
from torch.utils import data
from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm
from ..configuration_utils import ConfigMixin
@@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin):
h = nonlinearity(h)
h = self.conv_out(h)
return h
# dataset classes
class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
)
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# trainer class
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def cycle(dl):
while True:
for data_dl in dl:
yield data_dl
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
class Trainer(object):
def __init__(
self,
diffusion_model,
folder,
*,
ema_decay=0.995,
image_size=128,
train_batch_size=32,
train_lr=1e-4,
train_num_steps=100000,
gradient_accumulate_every=2,
amp=False,
step_start_ema=2000,
update_ema_every=10,
save_and_sample_every=1000,
results_folder="./results",
):
super().__init__()
self.model = diffusion_model
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)
self.update_ema_every = update_ema_every
self.step_start_ema = step_start_ema
self.save_and_sample_every = save_and_sample_every
self.batch_size = train_batch_size
self.image_size = diffusion_model.image_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
self.ds = Dataset(folder, image_size)
self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
self.step = 0
self.amp = amp
self.scaler = GradScaler(enabled=amp)
self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok=True)
self.reset_parameters()
def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())
def step_ema(self):
if self.step < self.step_start_ema:
self.reset_parameters()
return
self.ema.update_model_average(self.ema_model, self.model)
def save(self, milestone):
data = {
"step": self.step,
"model": self.model.state_dict(),
"ema": self.ema_model.state_dict(),
"scaler": self.scaler.state_dict(),
}
torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))
def load(self, milestone):
data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))
self.step = data["step"]
self.model.load_state_dict(data["model"])
self.ema_model.load_state_dict(data["ema"])
self.scaler.load_state_dict(data["scaler"])
def train(self):
with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()
with autocast(enabled=self.amp):
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()
pbar.set_description(f"loss: {loss.item():.4f}")
self.scaler.step(self.opt)
self.scaler.update()
self.opt.zero_grad()
if self.step % self.update_ema_every == 0:
self.step_ema()
if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
self.save(milestone)
self.step += 1
pbar.update(1)
print("training complete")


@@ -2,6 +2,7 @@ import math
import torch
try:
from einops import rearrange, repeat
except:
@@ -11,6 +12,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
@@ -47,9 +49,9 @@ class Rezero(torch.nn.Module):
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
padding=1), torch.nn.GroupNorm(
groups, dim_out), Mish())
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
def forward(self, x, mask):
output = self.block(x * mask)
@@ -59,8 +61,7 @@ class Block(torch.nn.Module):
class ResnetBlock(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlock, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
dim_out))
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
@@ -83,18 +84,16 @@ class LinearAttention(torch.nn.Module):
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
heads = self.heads, qkv=3)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
heads=self.heads, h=h, w=w)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
@@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module):
class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(
self,
dim,
dim_mults=(1, 2, 4),
groups=8,
n_spks=None,
spk_emb_dim=64,
n_feats=80,
pe_scale=1000
):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__()
self.register(
@@ -143,22 +133,22 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
n_feats=n_feats,
pe_scale=pe_scale
pe_scale=pe_scale,
)
self.dim = dim
self.dim_mults = dim_mults
self.groups = groups
self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
self.spk_emb_dim = spk_emb_dim
self.pe_scale = pe_scale
if n_spks > 1:
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.spk_mlp = torch.nn.Sequential(
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
torch.nn.Linear(dim * 4, dim))
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(torch.nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity()]))
self.downs.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
@@ -180,18 +175,23 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(torch.nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in)]))
self.ups.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in),
]
)
)
self.final_block = Block(dim, dim)
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
def forward(self, x, mask, mu, t, spk=None):
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)
t = self.time_pos_emb(t, scale=self.pe_scale)
t = self.mlp(t)
@@ -230,4 +230,4 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
x = self.final_block(x, mask)
output = self.final_conv(x * mask)
return (output * mask).squeeze(1)


@@ -57,14 +57,14 @@ class DiffusionPipeline(ConfigMixin):
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
for name, module in kwargs.items():
# check if the module is a pipeline module
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
# retrieve library
library = module.__module__.split(".")[0]
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
@@ -160,10 +160,10 @@ class DiffusionPipeline(ConfigMixin):
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
# import it here to avoid circular import
from diffusers import pipelines
# 4. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)


@@ -1,6 +1,6 @@
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM
from .pipeline_pndm import PNDM


@@ -283,7 +283,7 @@ class BDDM(DiffusionPipeline):
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device)
mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length)


@@ -24,11 +24,15 @@ import torch.utils.checkpoint
from torch import nn
import tqdm
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
try:
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
except:
print("Transformers is not installed")
pass
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
@@ -832,9 +836,7 @@ class GLIDE(DiffusionPipeline):
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = torch.randn(
(batch_size, self.text_unet.in_channels, 64, 64), generator=generator
).to(torch_device)
image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
# 2. Encode tokens
# an empty input is needed to guide the model away from it


@@ -39,14 +39,13 @@ def generate_path(duration, mask):
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
[1, 0], [0, 0]]))[:, :-1]
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
return loss
@@ -62,7 +61,7 @@ class LayerNorm(nn.Module):
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
@@ -72,8 +71,7 @@ class LayerNorm(nn.Module):
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
n_layers, p_dropout):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super(ConvReluNorm, self).__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module):
self.conv_layers = torch.nn.ModuleList()
self.norm_layers = torch.nn.ModuleList()
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.conv_layers.append(
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
@@ -114,11 +112,9 @@ class DurationPredictor(nn.Module):
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
@@ -136,9 +132,17 @@ class DurationPredictor(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, window_size=None,
heads_share=True, p_dropout=0.0, proximal_bias=False,
proximal_init=False):
def __init__(
self,
channels,
out_channels,
n_heads,
window_size=None,
heads_share=True,
p_dropout=0.0,
proximal_bias=False,
proximal_init=False,
):
super(MultiHeadAttention, self).__init__()
assert channels % n_heads == 0
@@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_k = torch.nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.emb_rel_v = torch.nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)
@@ -171,12 +177,12 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
@@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module):
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
dtype=scores.dtype)
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module):
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights,
value_relative_embeddings)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
@@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module):
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = torch.nn.functional.pad(
relative_embeddings, convert_pad_shape([[0, 0],
[pad_length, pad_length], [0, 0]]))
relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:slice_end_position]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
@@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
p_dropout=0.0):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
super(FFN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -267,10 +269,8 @@ class FFN(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
padding=kernel_size//2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
padding=kernel_size//2)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = torch.nn.Dropout(p_dropout)
def forward(self, x, x_mask):
@@ -282,8 +282,17 @@ class FFN(nn.Module):
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=None,
**kwargs,
):
super(Encoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@@ -299,11 +308,15 @@ class Encoder(nn.Module):
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
n_heads, window_size=window_size, p_dropout=p_dropout))
self.attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
filter_channels, kernel_size, p_dropout=p_dropout))
self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
@@ -321,9 +334,21 @@ class Encoder(nn.Module):
class TextEncoder(ModelMixin, ConfigMixin):
def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
filter_channels_dp, n_heads, n_layers, kernel_size,
p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
def __init__(
self,
n_vocab,
n_feats,
n_channels,
filter_channels,
filter_channels_dp,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=None,
spk_emb_dim=64,
n_spks=1,
):
super(TextEncoder, self).__init__()
self.register(
@@ -338,10 +363,9 @@ class TextEncoder(ModelMixin, ConfigMixin):
p_dropout=p_dropout,
window_size=window_size,
spk_emb_dim=spk_emb_dim,
n_spks=n_spks
n_spks=n_spks,
)
self.n_vocab = n_vocab
self.n_feats = n_feats
self.n_channels = n_channels
@@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin):
self.emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
kernel_size=5, n_layers=3, p_dropout=0.5)
self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)
self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
kernel_size, p_dropout, window_size=window_size)
self.encoder = Encoder(
n_channels + (spk_emb_dim if n_spks > 1 else 0),
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=window_size,
)
self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
kernel_size, p_dropout)
self.proj_w = DurationPredictor(
n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
)
def forward(self, x, x_lengths, spk=None):
x = self.emb(x) * math.sqrt(self.n_channels)


@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_predicted_image=clip_predicted_image,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
self.variance_type = variance_type


@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
return self.warmup_time_steps[num_inference_steps]
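For reference, the warmup-step arithmetic above can be reproduced in isolation. The concrete values below (`timesteps=1000`, `num_inference_steps=50`, `pndm_order=4`) are illustrative assumptions for the sketch, not values read from the scheduler's defaults.
```python
import numpy as np

timesteps = 1000
num_inference_steps = 50
pndm_order = 4

inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))

# each of the last `pndm_order` inference steps is repeated and offset by half a stride
warmup_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)

# same post-processing as the scheduler: drop the last entry, duplicate, trim both ends, reverse
warmup = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
print(warmup)
```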
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
x_delta = (at_next - at) * (
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
)
x_next = x + x_delta
return x_next
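Written out as math, the reformatted expression above is a direct transcription of the code, with $a =$ `alphas_cump[t + 1]`, $a_{\text{next}} =$ `alphas_cump[t_next + 1]` and $e_t =$ `et`:

$$
x_{\text{next}} = x + (a_{\text{next}} - a)\left[\frac{x}{\sqrt{a}\,\bigl(\sqrt{a} + \sqrt{a_{\text{next}}}\bigr)} - \frac{e_t}{\sqrt{a}\,\bigl(\sqrt{(1 - a_{\text{next}})\,a} + \sqrt{(1 - a)\,a_{\text{next}}}\bigr)}\right]
$$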


@@ -19,7 +19,18 @@ import unittest
import torch
from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers import (
BDDM,
DDIM,
DDPM,
GLIDE,
PNDM,
DDIMScheduler,
DDPMScheduler,
LatentDiffusion,
PNDMScheduler,
UNetModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave