Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
remove torchvision dependency
@@ -144,9 +144,11 @@ if __name__ == "__main__":
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)

args = parser.parse_args()
setup.py
@@ -87,7 +86,6 @@ _deps = [
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"torchvision",
]

# this is a lookup table with items like:
@@ -172,7 +171,6 @@ install_requires = [
deps["regex"],
deps["requests"],
deps["torch"],
deps["torchvision"],
deps["Pillow"],
]
@@ -6,10 +6,10 @@ __version__ = "0.0.3"

from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -226,7 +226,7 @@ class ConfigMixin:
return json.loads(text)

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
return f"{self.__class__.__name__} {self.to_json_string()}"

@property
def config(self) -> Dict[str, Any]:
@@ -13,5 +13,4 @@ deps = {
"regex": "regex!=2019.12.17",
"requests": "requests",
"torch": "torch>=1.4",
"torchvision": "torchvision",
}
@@ -17,6 +17,6 @@
# limitations under the License.

from .unet import UNetModel
from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_grad_tts import UNetGradTTSModel
@@ -26,7 +26,6 @@ from torch.optim import Adam
from torch.utils import data

from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import ConfigMixin
@@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin):
h = nonlinearity(h)
h = self.conv_out(h)
return h


# dataset classes


class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
)

def __len__(self):
return len(self.paths)

def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)


# trainer class
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta

def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)

def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new


def cycle(dl):
while True:
for data_dl in dl:
yield data_dl


def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr


class Trainer(object):
def __init__(
self,
diffusion_model,
folder,
*,
ema_decay=0.995,
image_size=128,
train_batch_size=32,
train_lr=1e-4,
train_num_steps=100000,
gradient_accumulate_every=2,
amp=False,
step_start_ema=2000,
update_ema_every=10,
save_and_sample_every=1000,
results_folder="./results",
):
super().__init__()
self.model = diffusion_model
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)
self.update_ema_every = update_ema_every

self.step_start_ema = step_start_ema
self.save_and_sample_every = save_and_sample_every

self.batch_size = train_batch_size
self.image_size = diffusion_model.image_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps

self.ds = Dataset(folder, image_size)
self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)

self.step = 0

self.amp = amp
self.scaler = GradScaler(enabled=amp)

self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok=True)

self.reset_parameters()

def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())

def step_ema(self):
if self.step < self.step_start_ema:
self.reset_parameters()
return
self.ema.update_model_average(self.ema_model, self.model)

def save(self, milestone):
data = {
"step": self.step,
"model": self.model.state_dict(),
"ema": self.ema_model.state_dict(),
"scaler": self.scaler.state_dict(),
}
torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))

def load(self, milestone):
data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))

self.step = data["step"]
self.model.load_state_dict(data["model"])
self.ema_model.load_state_dict(data["ema"])
self.scaler.load_state_dict(data["scaler"])

def train(self):
with tqdm(initial=self.step, total=self.train_num_steps) as pbar:

while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()

with autocast(enabled=self.amp):
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()

pbar.set_description(f"loss: {loss.item():.4f}")

self.scaler.step(self.opt)
self.scaler.update()
self.opt.zero_grad()

if self.step % self.update_ema_every == 0:
self.step_ema()

if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()

milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
self.save(milestone)

self.step += 1
pbar.update(1)

print("training complete")
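The hunk above deletes the torchvision-based Dataset/EMA/Trainer helpers from models/unet.py. Training scripts that relied on them need their own image loading; below is a minimal sketch of a torchvision-free dataset using only PIL and NumPy. The class name, resize, and normalization choices are illustrative assumptions, not part of this commit.

import numpy as np
import torch
from pathlib import Path
from PIL import Image
from torch.utils import data


class PILImageDataset(data.Dataset):
    """Hypothetical stand-in for the removed torchvision-based Dataset."""

    def __init__(self, folder, image_size, exts=("jpg", "jpeg", "png")):
        super().__init__()
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(folder).glob(f"**/*.{ext}")]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("RGB")
        # Resize to a square of side image_size with PIL only (no torchvision transforms).
        img = img.resize((self.image_size, self.image_size), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
        return torch.from_numpy(arr).permute(2, 0, 1)  # CHW, like ToTensor()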
@@ -2,6 +2,7 @@ import math

import torch


try:
from einops import rearrange, repeat
except:
@@ -11,6 +12,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
@@ -47,9 +49,9 @@ class Rezero(torch.nn.Module):
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
padding=1), torch.nn.GroupNorm(
groups, dim_out), Mish())
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)

def forward(self, x, mask):
output = self.block(x * mask)
@@ -59,8 +61,7 @@ class Block(torch.nn.Module):
class ResnetBlock(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlock, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
dim_out))
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))

self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
@@ -83,18 +84,16 @@ class LinearAttention(torch.nn.Module):
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
heads = self.heads, qkv=3)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
heads=self.heads, h=h, w=w)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
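The LinearAttention hunk above only reformats the rearrange/einsum calls; as context, here is a small standalone sketch that checks the tensor shapes those pattern strings produce. The concrete sizes are arbitrary and only for illustration.

import torch
from einops import rearrange

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
hidden_dim = heads * dim_head

qkv = torch.randn(b, hidden_dim * 3, h, w)  # what the 1x1 to_qkv conv would emit
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
print(q.shape)  # torch.Size([2, 4, 32, 64])

k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)    # (b, heads, c, c)
out = torch.einsum("bhde,bhdn->bhen", context, q)  # (b, heads, c, h*w)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=heads, h=h, w=w)
print(out.shape)  # torch.Size([2, 128, 8, 8])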
@@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module):


class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(
self,
dim,
dim_mults=(1, 2, 4),
groups=8,
n_spks=None,
spk_emb_dim=64,
n_feats=80,
pe_scale=1000
):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__()

self.register(
@@ -143,22 +133,22 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
n_feats=n_feats,
pe_scale=pe_scale
pe_scale=pe_scale,
)


self.dim = dim
self.dim_mults = dim_mults
self.groups = groups
self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
self.spk_emb_dim = spk_emb_dim
self.pe_scale = pe_scale


if n_spks > 1:
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.spk_mlp = torch.nn.Sequential(
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
torch.nn.Linear(dim * 4, dim))
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):

for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(torch.nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity()]))
self.downs.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity(),
]
)
)

mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
@@ -180,18 +175,23 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(torch.nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in)]))
self.ups.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in),
]
)
)
self.final_block = Block(dim, dim)
self.final_conv = torch.nn.Conv2d(dim, 1, 1)

def forward(self, x, mask, mu, t, spk=None):
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)


t = self.time_pos_emb(t, scale=self.pe_scale)
t = self.mlp(t)

@@ -230,4 +230,4 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
x = self.final_block(x, mask)
output = self.final_conv(x * mask)

return (output * mask).squeeze(1)
return (output * mask).squeeze(1)
@@ -57,14 +57,14 @@ class DiffusionPipeline(ConfigMixin):
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines


for name, module in kwargs.items():
# check if the module is a pipeline module
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])


# retrive library
library = module.__module__.split(".")[0]


# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
@@ -160,10 +160,10 @@ class DiffusionPipeline(ConfigMixin):
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

init_kwargs = {}


# import it here to avoid circular import
from diffusers import pipelines


# 4. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
@@ -1,6 +1,6 @@
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM
from .pipeline_pndm import PNDM
@@ -283,7 +283,7 @@ class BDDM(DiffusionPipeline):
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

self.diffwave.to(torch_device)


mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length)
@@ -832,9 +832,7 @@ class GLIDE(DiffusionPipeline):

# 1. Sample gaussian noise
batch_size = 2  # second image is empty for classifier-free guidance
image = torch.randn(
(batch_size, self.text_unet.in_channels, 64, 64), generator=generator
).to(torch_device)
image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)

# 2. Encode tokens
# an empty input is needed to guide the model away from it
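The comments in this hunk refer to classifier-free guidance: the batch of two holds a text-conditioned sample and an empty-prompt sample, and their noise predictions are blended. A minimal sketch of that blend follows; the function and variable names are illustrative, and the GLIDE pipeline's own guidance code is not part of this diff.

import torch

def classifier_free_guidance(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Push the prediction away from the unconditional branch and toward the conditioned one.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)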
@@ -39,14 +39,13 @@ def generate_path(duration, mask):
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
[1, 0], [0, 0]]))[:, :-1]
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path


def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
return loss


@@ -62,7 +61,7 @@ class LayerNorm(nn.Module):
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

x = (x - mean) * torch.rsqrt(variance + self.eps)

@@ -72,8 +71,7 @@ class LayerNorm(nn.Module):


class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
n_layers, p_dropout):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super(ConvReluNorm, self).__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module):

self.conv_layers = torch.nn.ModuleList()
self.norm_layers = torch.nn.ModuleList()
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.conv_layers.append(
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
@@ -114,11 +112,9 @@ class DurationPredictor(nn.Module):
self.p_dropout = p_dropout

self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

@@ -136,9 +132,17 @@ class DurationPredictor(nn.Module):


class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, window_size=None,
heads_share=True, p_dropout=0.0, proximal_bias=False,
proximal_init=False):
def __init__(
self,
channels,
out_channels,
n_heads,
window_size=None,
heads_share=True,
p_dropout=0.0,
proximal_bias=False,
proximal_init=False,
):
super(MultiHeadAttention, self).__init__()
assert channels % n_heads == 0

@@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_k = torch.nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.emb_rel_v = torch.nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)

@@ -171,12 +177,12 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)


def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)


x, self.attn = self.attention(q, k, v, mask=attn_mask)

x = self.conv_o(x)
@@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module):
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
dtype=scores.dtype)
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module):
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights,
value_relative_embeddings)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
@@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module):
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = torch.nn.functional.pad(
relative_embeddings, convert_pad_shape([[0, 0],
[pad_length, pad_length], [0, 0]]))
relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:slice_end_position]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings

def _relative_position_to_absolute_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final

def _absolute_position_to_relative_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final

def _attention_bias_proximal(self, length):
@@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module):


class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
p_dropout=0.0):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
super(FFN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -267,10 +269,8 @@ class FFN(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout

self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
padding=kernel_size//2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
padding=kernel_size//2)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.drop = torch.nn.Dropout(p_dropout)

def forward(self, x, x_mask):
@@ -282,8 +282,17 @@ class FFN(nn.Module):


class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=None,
**kwargs,
):
super(Encoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@@ -299,11 +308,15 @@ class Encoder(nn.Module):
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
n_heads, window_size=window_size, p_dropout=p_dropout))
self.attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
filter_channels, kernel_size, p_dropout=p_dropout))
self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))

def forward(self, x, x_mask):
@@ -321,9 +334,21 @@ class Encoder(nn.Module):


class TextEncoder(ModelMixin, ConfigMixin):
def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
filter_channels_dp, n_heads, n_layers, kernel_size,
p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
def __init__(
self,
n_vocab,
n_feats,
n_channels,
filter_channels,
filter_channels_dp,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=None,
spk_emb_dim=64,
n_spks=1,
):
super(TextEncoder, self).__init__()

self.register(
@@ -338,10 +363,9 @@ class TextEncoder(ModelMixin, ConfigMixin):
p_dropout=p_dropout,
window_size=window_size,
spk_emb_dim=spk_emb_dim,
n_spks=n_spks
n_spks=n_spks,
)


self.n_vocab = n_vocab
self.n_feats = n_feats
self.n_channels = n_channels
@@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin):
self.emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
kernel_size=5, n_layers=3, p_dropout=0.5)
self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
kernel_size, p_dropout, window_size=window_size)
self.encoder = Encoder(
n_channels + (spk_emb_dim if n_spks > 1 else 0),
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=window_size,
)

self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
kernel_size, p_dropout)
self.proj_w = DurationPredictor(
n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
)

def forward(self, x, x_lengths, spk=None):
x = self.emb(x) * math.sqrt(self.n_channels)
@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_predicted_image=clip_predicted_image,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
self.variance_type = variance_type
@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))

warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))

return self.warmup_time_steps[num_inference_steps]
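The hunk above only reformats the PNDM warmup computation; as context, here is a small worked example of what it produces. The values (timesteps=1000, num_inference_steps=50, pndm_order=4) are illustrative assumptions, not taken from the diff.

import numpy as np

# Illustrative values only; the scheduler reads these from its own configuration.
timesteps, num_inference_steps, pndm_order = 1000, 50, 4

inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))

warmup_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
# warmup_time_steps -> [920 930 940 950 960 970 980 990]

warmup = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
# warmup -> [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]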
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)

x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
x_delta = (at_next - at) * (
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
)

x_next = x + x_delta
return x_next
@@ -19,7 +19,18 @@ import unittest

import torch

from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers import (
BDDM,
DDIM,
DDPM,
GLIDE,
PNDM,
DDIMScheduler,
DDPMScheduler,
LatentDiffusion,
PNDMScheduler,
UNetModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave