From f8cd3a20e47be25e96929ed2929f2e402fc2d074 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 15 Jun 2022 12:25:48 +0200
Subject: [PATCH 1/6] Update README.md

---
 README.md | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index d8ff34e7ee..f97838df89 100644
--- a/README.md
+++ b/README.md
@@ -58,12 +58,14 @@ git clone https://github.com/huggingface/diffusers.git
 cd diffusers && pip install -e .
 ```
 
-### 1. `diffusers` as a central modular diffusion and sampler library
+### 1. `diffusers` as a toolbox for schedulers and models.
 
 `diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for the own use cases.
 It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
 Both models and schedulers should be load- and saveable from the Hub.
 
+For more examples see [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) and [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)
+
 #### **Example for [DDPM](https://arxiv.org/abs/2006.11239):**
 
 ```python
@@ -171,25 +173,35 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-### 2. `diffusers` as a collection of most important Diffusion systems (GLIDE, Dalle, ...)
-`models` directory in repository hosts the complete code necessary for running a diffusion system as well as to train it. A `DiffusionPipeline` class allows to easily run the diffusion model in inference:
+### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
 
-#### **Example image generation with DDPM**
+For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+
+#### **Example image generation with PNDM**
 
 ```python
-from diffusers import DiffusionPipeline
+from diffusers import PNDM, UNetModel, PNDMScheduler
 import PIL.Image
 import numpy as np
+import torch
+
+model_id = "fusing/ddim-celeba-hq"
+
+model = UNetModel.from_pretrained(model_id)
+scheduler = PNDMScheduler()
 
 # load model and scheduler
-ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")
+ddpm = PNDM(unet=model, noise_scheduler=scheduler)
 
 # run pipeline in inference (sample random noise and denoise)
-image = ddpm()
+with torch.no_grad():
+    image = ddpm()
 
 # process image to PIL
 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = (image_processed + 1.0) * 127.5
+image_processed = (image_processed + 1.0) / 2
+image_processed = torch.clamp(image_processed, 0.0, 1.0)
+image_processed = image_processed * 255
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])

From 17c574a16dd505d3b280f39681e9715b7e252194 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 15 Jun 2022 12:35:47 +0200
Subject: [PATCH 2/6] remove torchvision dependency

---
 examples/train_ddpm.py | 8 +-
 setup.py | 2 -
 src/diffusers/__init__.py | 8 +-
 src/diffusers/configuration_utils.py | 2 +-
 src/diffusers/dependency_versions_table.py | 1 -
 src/diffusers/models/__init__.py | 4 +-
 src/diffusers/models/unet.py | 169 -------------------
 src/diffusers/models/unet_grad_tts.py | 82 ++++-----
 src/diffusers/pipeline_utils.py | 10 +-
 src/diffusers/pipelines/__init__.py | 4 +-
 src/diffusers/pipelines/pipeline_bddm.py | 2 +-
 src/diffusers/pipelines/pipeline_glide.py | 4 +-
 src/diffusers/pipelines/pipeline_grad_tts.py | 153 ++++++++++-------
 src/diffusers/schedulers/scheduling_ddpm.py | 2 +-
 src/diffusers/schedulers/scheduling_pndm.py | 9 +-
 tests/test_modeling_utils.py | 13 +-
 16 files changed, 174 insertions(+), 299 deletions(-)

diff --git a/examples/train_ddpm.py b/examples/train_ddpm.py
index 7eb0b9d34e..2aa45b6786 100644
--- a/examples/train_ddpm.py
+++ b/examples/train_ddpm.py
@@ -144,9 +144,11 @@ if __name__ == "__main__":
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
-        help="Whether to use mixed precision. Choose"
-        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-        "and an Nvidia Ampere GPU.",
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+ ), ) args = parser.parse_args() diff --git a/setup.py b/setup.py index da6058221f..14e39af7a0 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,6 @@ _deps = [ "regex!=2019.12.17", "requests", "torch>=1.4", - "torchvision", ] # this is a lookup table with items like: @@ -172,7 +171,6 @@ install_requires = [ deps["regex"], deps["requests"], deps["torch"], - deps["torchvision"], deps["Pillow"], ] diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2f4d2ab6dc..929436345e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -6,10 +6,10 @@ __version__ = "0.0.3" from .modeling_utils import ModelMixin from .models.unet import UNetModel -from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel -from .models.unet_ldm import UNetLDMModel +from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_grad_tts import UNetGradTTSModel +from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM -from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler +from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion +from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 5ba5ddec28..a74b5ea00a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -226,7 +226,7 @@ class ConfigMixin: return json.loads(text) def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" + return f"{self.__class__.__name__} {self.to_json_string()}" @property def config(self) -> Dict[str, Any]: diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index b972b9a0a6..5793d2fc85 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -13,5 +13,4 @@ deps = { "regex": "regex!=2019.12.17", "requests": "requests", "torch": "torch>=1.4", - "torchvision": "torchvision", } diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 9104bb9031..1a657f224e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -17,6 +17,6 @@ # limitations under the License. 
from .unet import UNetModel -from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel +from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel +from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel -from .unet_grad_tts import UNetGradTTSModel \ No newline at end of file diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 5621171149..0f7559ecf4 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -26,7 +26,6 @@ from torch.optim import Adam from torch.utils import data from PIL import Image -from torchvision import transforms, utils from tqdm import tqdm from ..configuration_utils import ConfigMixin @@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin): h = nonlinearity(h) h = self.conv_out(h) return h - - -# dataset classes - - -class Dataset(data.Dataset): - def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]): - super().__init__() - self.folder = folder - self.image_size = image_size - self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")] - - self.transform = transforms.Compose( - [ - transforms.Resize(image_size), - transforms.RandomHorizontalFlip(), - transforms.CenterCrop(image_size), - transforms.ToTensor(), - ] - ) - - def __len__(self): - return len(self.paths) - - def __getitem__(self, index): - path = self.paths[index] - img = Image.open(path) - return self.transform(img) - - -# trainer class -class EMA: - def __init__(self, beta): - super().__init__() - self.beta = beta - - def update_model_average(self, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): - old_weight, up_weight = ma_params.data, current_params.data - ma_params.data = self.update_average(old_weight, up_weight) - - def update_average(self, old, new): - if old is None: - return new - return old * self.beta + (1 - self.beta) * new - - -def cycle(dl): - while True: - for data_dl in dl: - yield data_dl - - -def num_to_groups(num, divisor): - groups = num // divisor - remainder = num % divisor - arr = [divisor] * groups - if remainder > 0: - arr.append(remainder) - return arr - - -class Trainer(object): - def __init__( - self, - diffusion_model, - folder, - *, - ema_decay=0.995, - image_size=128, - train_batch_size=32, - train_lr=1e-4, - train_num_steps=100000, - gradient_accumulate_every=2, - amp=False, - step_start_ema=2000, - update_ema_every=10, - save_and_sample_every=1000, - results_folder="./results", - ): - super().__init__() - self.model = diffusion_model - self.ema = EMA(ema_decay) - self.ema_model = copy.deepcopy(self.model) - self.update_ema_every = update_ema_every - - self.step_start_ema = step_start_ema - self.save_and_sample_every = save_and_sample_every - - self.batch_size = train_batch_size - self.image_size = diffusion_model.image_size - self.gradient_accumulate_every = gradient_accumulate_every - self.train_num_steps = train_num_steps - - self.ds = Dataset(folder, image_size) - self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True)) - self.opt = Adam(diffusion_model.parameters(), lr=train_lr) - - self.step = 0 - - self.amp = amp - self.scaler = GradScaler(enabled=amp) - - self.results_folder = Path(results_folder) - self.results_folder.mkdir(exist_ok=True) - - self.reset_parameters() - - def reset_parameters(self): - self.ema_model.load_state_dict(self.model.state_dict()) - - def step_ema(self): - 
if self.step < self.step_start_ema: - self.reset_parameters() - return - self.ema.update_model_average(self.ema_model, self.model) - - def save(self, milestone): - data = { - "step": self.step, - "model": self.model.state_dict(), - "ema": self.ema_model.state_dict(), - "scaler": self.scaler.state_dict(), - } - torch.save(data, str(self.results_folder / f"model-{milestone}.pt")) - - def load(self, milestone): - data = torch.load(str(self.results_folder / f"model-{milestone}.pt")) - - self.step = data["step"] - self.model.load_state_dict(data["model"]) - self.ema_model.load_state_dict(data["ema"]) - self.scaler.load_state_dict(data["scaler"]) - - def train(self): - with tqdm(initial=self.step, total=self.train_num_steps) as pbar: - - while self.step < self.train_num_steps: - for i in range(self.gradient_accumulate_every): - data = next(self.dl).cuda() - - with autocast(enabled=self.amp): - loss = self.model(data) - self.scaler.scale(loss / self.gradient_accumulate_every).backward() - - pbar.set_description(f"loss: {loss.item():.4f}") - - self.scaler.step(self.opt) - self.scaler.update() - self.opt.zero_grad() - - if self.step % self.update_ema_every == 0: - self.step_ema() - - if self.step != 0 and self.step % self.save_and_sample_every == 0: - self.ema_model.eval() - - milestone = self.step // self.save_and_sample_every - batches = num_to_groups(36, self.batch_size) - all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches)) - all_images = torch.cat(all_images_list, dim=0) - utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6) - self.save(milestone) - - self.step += 1 - pbar.update(1) - - print("training complete") diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index de2d6aa2f1..08501c4b60 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -2,6 +2,7 @@ import math import torch + try: from einops import rearrange, repeat except: @@ -11,6 +12,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin + class Mish(torch.nn.Module): def forward(self, x): return x * torch.tanh(torch.nn.functional.softplus(x)) @@ -47,9 +49,9 @@ class Rezero(torch.nn.Module): class Block(torch.nn.Module): def __init__(self, dim, dim_out, groups=8): super(Block, self).__init__() - self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, - padding=1), torch.nn.GroupNorm( - groups, dim_out), Mish()) + self.block = torch.nn.Sequential( + torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() + ) def forward(self, x, mask): output = self.block(x * mask) @@ -59,8 +61,7 @@ class Block(torch.nn.Module): class ResnetBlock(torch.nn.Module): def __init__(self, dim, dim_out, time_emb_dim, groups=8): super(ResnetBlock, self).__init__() - self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, - dim_out)) + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) @@ -83,18 +84,16 @@ class LinearAttention(torch.nn.Module): self.heads = heads hidden_dim = dim_head * heads self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) + self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b 
heads c (h w)', - heads = self.heads, qkv=3) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', - heads=self.heads, h=h, w=w) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) return self.to_out(out) @@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module): class UNetGradTTSModel(ModelMixin, ConfigMixin): - def __init__( - self, - dim, - dim_mults=(1, 2, 4), - groups=8, - n_spks=None, - spk_emb_dim=64, - n_feats=80, - pe_scale=1000 - ): + def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): super(UNetGradTTSModel, self).__init__() self.register( @@ -143,22 +133,22 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): n_spks=n_spks, spk_emb_dim=spk_emb_dim, n_feats=n_feats, - pe_scale=pe_scale + pe_scale=pe_scale, ) - + self.dim = dim self.dim_mults = dim_mults self.groups = groups self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 self.spk_emb_dim = spk_emb_dim self.pe_scale = pe_scale - + if n_spks > 1: - self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), - torch.nn.Linear(spk_emb_dim * 4, n_feats)) + self.spk_mlp = torch.nn.Sequential( + torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) + ) self.time_pos_emb = SinusoidalPosEmb(dim) - self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), - torch.nn.Linear(dim * 4, dim)) + self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) @@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) - self.downs.append(torch.nn.ModuleList([ - ResnetBlock(dim_in, dim_out, time_emb_dim=dim), - ResnetBlock(dim_out, dim_out, time_emb_dim=dim), - Residual(Rezero(LinearAttention(dim_out))), - Downsample(dim_out) if not is_last else torch.nn.Identity()])) + self.downs.append( + torch.nn.ModuleList( + [ + ResnetBlock(dim_in, dim_out, time_emb_dim=dim), + ResnetBlock(dim_out, dim_out, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else torch.nn.Identity(), + ] + ) + ) mid_dim = dims[-1] self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) @@ -180,18 +175,23 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - self.ups.append(torch.nn.ModuleList([ - ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), - ResnetBlock(dim_in, dim_in, time_emb_dim=dim), - Residual(Rezero(LinearAttention(dim_in))), - Upsample(dim_in)])) + self.ups.append( + torch.nn.ModuleList( + [ + ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), + ResnetBlock(dim_in, dim_in, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_in))), + Upsample(dim_in), + ] + ) + ) self.final_block = Block(dim, dim) self.final_conv = torch.nn.Conv2d(dim, 1, 1) def forward(self, x, mask, mu, t, spk=None): if not 
isinstance(spk, type(None)): s = self.spk_mlp(spk) - + t = self.time_pos_emb(t, scale=self.pe_scale) t = self.mlp(t) @@ -230,4 +230,4 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): x = self.final_block(x, mask) output = self.final_conv(x * mask) - return (output * mask).squeeze(1) \ No newline at end of file + return (output * mask).squeeze(1) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 77be534009..ceae102d7a 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -57,14 +57,14 @@ class DiffusionPipeline(ConfigMixin): def register_modules(self, **kwargs): # import it here to avoid circular import from diffusers import pipelines - + for name, module in kwargs.items(): # check if the module is a pipeline module is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1]) - + # retrive library library = module.__module__.split(".")[0] - + # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline # so we set the library to module name. @@ -160,10 +160,10 @@ class DiffusionPipeline(ConfigMixin): init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} - + # import it here to avoid circular import from diffusers import pipelines - + # 4. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): is_pipeline_module = hasattr(pipelines, library_name) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e0d2bf2e30..fdeccda8fe 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,6 +1,6 @@ +from .pipeline_bddm import BDDM from .pipeline_ddim import DDIM from .pipeline_ddpm import DDPM -from .pipeline_pndm import PNDM from .pipeline_glide import GLIDE from .pipeline_latent_diffusion import LatentDiffusion -from .pipeline_bddm import BDDM +from .pipeline_pndm import PNDM diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py index ee9e628f4d..de0689cea1 100644 --- a/src/diffusers/pipelines/pipeline_bddm.py +++ b/src/diffusers/pipelines/pipeline_bddm.py @@ -283,7 +283,7 @@ class BDDM(DiffusionPipeline): torch_device = "cuda" if torch.cuda.is_available() else "cpu" self.diffwave.to(torch_device) - + mel_spectrogram = mel_spectrogram.to(torch_device) audio_length = mel_spectrogram.size(-1) * 256 audio_size = (1, 1, audio_length) diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 138ce9d2f2..0bd96bbccf 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -832,9 +832,7 @@ class GLIDE(DiffusionPipeline): # 1. Sample gaussian noise batch_size = 2 # second image is empty for classifier-free guidance - image = torch.randn( - (batch_size, self.text_unet.in_channels, 64, 64), generator=generator - ).to(torch_device) + image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device) # 2. 
Encode tokens # an empty input is needed to guide the model away from it diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 2d8f694638..048db3785f 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -39,14 +39,13 @@ def generate_path(duration, mask): cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) - path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], - [1, 0], [0, 0]]))[:, :-1] + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] path = path * mask return path def duration_loss(logw, logw_, lengths): - loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) return loss @@ -62,7 +61,7 @@ class LayerNorm(nn.Module): def forward(self, x): n_dims = len(x.shape) mean = torch.mean(x, 1, keepdim=True) - variance = torch.mean((x - mean)**2, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) x = (x - mean) * torch.rsqrt(variance + self.eps) @@ -72,8 +71,7 @@ class LayerNorm(nn.Module): class ConvReluNorm(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, - n_layers, p_dropout): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): super(ConvReluNorm, self).__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module): self.conv_layers = torch.nn.ModuleList() self.norm_layers = torch.nn.ModuleList() - self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, - kernel_size, padding=kernel_size//2)) + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) for _ in range(n_layers - 1): - self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, - kernel_size, padding=kernel_size//2)) + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() @@ -114,11 +112,9 @@ class DurationPredictor(nn.Module): self.p_dropout = p_dropout self.drop = torch.nn.Dropout(p_dropout) - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, - kernel_size, padding=kernel_size//2) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_1 = LayerNorm(filter_channels) - self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, - kernel_size, padding=kernel_size//2) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_2 = LayerNorm(filter_channels) self.proj = torch.nn.Conv1d(filter_channels, 1, 1) @@ -136,9 +132,17 @@ class DurationPredictor(nn.Module): class MultiHeadAttention(nn.Module): - def __init__(self, channels, out_channels, n_heads, window_size=None, - heads_share=True, p_dropout=0.0, proximal_bias=False, - proximal_init=False): + def __init__( + self, + channels, + out_channels, + n_heads, + window_size=None, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + 
proximal_init=False, + ): super(MultiHeadAttention, self).__init__() assert channels % n_heads == 0 @@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module): if window_size is not None: n_heads_rel = 1 if heads_share else n_heads rel_stddev = self.k_channels**-0.5 - self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, - window_size * 2 + 1, self.k_channels) * rel_stddev) - self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, - window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_k = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.emb_rel_v = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev + ) self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) self.drop = torch.nn.Dropout(p_dropout) @@ -171,12 +177,12 @@ class MultiHeadAttention(nn.Module): self.conv_k.weight.data.copy_(self.conv_q.weight.data) self.conv_k.bias.data.copy_(self.conv_q.bias.data) torch.nn.init.xavier_uniform_(self.conv_v.weight) - + def forward(self, x, c, attn_mask=None): q = self.conv_q(x) k = self.conv_k(c) v = self.conv_v(c) - + x, self.attn = self.attention(q, k, v, mask=attn_mask) x = self.conv_o(x) @@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module): scores = scores + scores_local if self.proximal_bias: assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, - dtype=scores.dtype) + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) p_attn = torch.nn.functional.softmax(scores, dim=-1) @@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module): if self.window_size is not None: relative_weights = self._absolute_position_to_relative_position(p_attn) value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values(relative_weights, - value_relative_embeddings) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) output = output.transpose(2, 3).contiguous().view(b, d, t_t) return output, p_attn @@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module): slice_end_position = slice_start_position + 2 * length - 1 if pad_length > 0: padded_relative_embeddings = torch.nn.functional.pad( - relative_embeddings, convert_pad_shape([[0, 0], - [pad_length, pad_length], [0, 0]])) + relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]) + ) else: padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[:, - slice_start_position:slice_end_position] + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] return used_relative_embeddings def _relative_position_to_absolute_position(self, x): batch, heads, length, _ = x.size() - x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) - x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + x_final = x_flat.view([batch, heads, length + 1, 2 
* length - 1])[:, :, :length, length - 1 :] return x_final def _absolute_position_to_relative_position(self, x): batch, heads, length, _ = x.size() - x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) - x_flat = x.view([batch, heads, length**2 + length*(length - 1)]) + x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) - x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] return x_final def _attention_bias_proximal(self, length): @@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module): class FFN(nn.Module): - def __init__(self, in_channels, out_channels, filter_channels, kernel_size, - p_dropout=0.0): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): super(FFN, self).__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -267,10 +269,8 @@ class FFN(nn.Module): self.kernel_size = kernel_size self.p_dropout = p_dropout - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, - padding=kernel_size//2) - self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, - padding=kernel_size//2) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) self.drop = torch.nn.Dropout(p_dropout) def forward(self, x, x_mask): @@ -282,8 +282,17 @@ class FFN(nn.Module): class Encoder(nn.Module): - def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, - kernel_size=1, p_dropout=0.0, window_size=None, **kwargs): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=None, + **kwargs, + ): super(Encoder, self).__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -299,11 +308,15 @@ class Encoder(nn.Module): self.ffn_layers = torch.nn.ModuleList() self.norm_layers_2 = torch.nn.ModuleList() for _ in range(self.n_layers): - self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, - n_heads, window_size=window_size, p_dropout=p_dropout)) + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout + ) + ) self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append(FFN(hidden_channels, hidden_channels, - filter_channels, kernel_size, p_dropout=p_dropout)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout) + ) self.norm_layers_2.append(LayerNorm(hidden_channels)) def forward(self, x, x_mask): @@ -321,9 +334,21 @@ class Encoder(nn.Module): class TextEncoder(ModelMixin, ConfigMixin): - def __init__(self, n_vocab, n_feats, n_channels, filter_channels, - filter_channels_dp, n_heads, n_layers, kernel_size, - p_dropout, window_size=None, spk_emb_dim=64, n_spks=1): + def __init__( + self, + n_vocab, + n_feats, + n_channels, + filter_channels, + filter_channels_dp, + n_heads, + n_layers, + kernel_size, + p_dropout, + window_size=None, + spk_emb_dim=64, + n_spks=1, + ): super(TextEncoder, self).__init__() self.register( @@ -338,10 +363,9 @@ class 
TextEncoder(ModelMixin, ConfigMixin): p_dropout=p_dropout, window_size=window_size, spk_emb_dim=spk_emb_dim, - n_spks=n_spks + n_spks=n_spks, ) - - + self.n_vocab = n_vocab self.n_feats = n_feats self.n_channels = n_channels @@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin): self.emb = torch.nn.Embedding(n_vocab, n_channels) torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) - self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, - kernel_size=5, n_layers=3, p_dropout=0.5) + self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5) - self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers, - kernel_size, p_dropout, window_size=window_size) + self.encoder = Encoder( + n_channels + (spk_emb_dim if n_spks > 1 else 0), + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + window_size=window_size, + ) self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1) - self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, - kernel_size, p_dropout) + self.proj_w = DurationPredictor( + n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout + ) def forward(self, x, x_lengths, spk=None): x = self.emb(x) * math.sqrt(self.n_channels) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 81f0f849dc..e2fc289046 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): clip_predicted_image=clip_predicted_image, ) self.timesteps = int(timesteps) - self.timestep_values = timestep_values # save the fixed timestep values for BDDM + self.timestep_values = timestep_values # save the fixed timestep values for BDDM self.clip_image = clip_predicted_image self.variance_type = variance_type diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 85fa6fb2f5..ee6f3bcf7a 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) - warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order) + warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order + ) self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1])) return self.warmup_time_steps[num_inference_steps] @@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): at = alphas_cump[t + 1].view(-1, 1, 1, 1) at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1) - x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) + x_delta = (at_next - at) * ( + (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x + - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et + ) x_next = x + x_delta return x_next diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 
b3dd5ef64a..e78f01a458 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,7 +19,18 @@ import unittest import torch -from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel +from diffusers import ( + BDDM, + DDIM, + DDPM, + GLIDE, + PNDM, + DDIMScheduler, + DDPMScheduler, + LatentDiffusion, + PNDMScheduler, + UNetModel, +) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_bddm import DiffWave From 1ab81f3b5ba23631a22edb068587d40d1d15e055 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jun 2022 12:41:57 +0200 Subject: [PATCH 3/6] Update README.md --- README.md | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f97838df89..2b4ce223f8 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,10 @@ The class provides functionality to compute previous image according to alpha, b ## Quickstart +### Installation + +**Note**: If you want to run PyTorch on GPU on a CUDA-compatible machine, please make sure to install the corresponding `torch` version from the +[official website]( ``` git clone https://github.com/huggingface/diffusers.git cd diffusers && pip install -e . @@ -84,29 +88,29 @@ unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise image = torch.randn( - (1, unet.in_channels, unet.resolution, unet.resolution), - generator=generator, + (1, unet.in_channels, unet.resolution, unet.resolution), + generator=generator, ) image = image.to(torch_device) # 3. Denoise num_prediction_steps = len(noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): - # predict noise residual - with torch.no_grad(): - residual = unet(image, t) + # predict noise residual + with torch.no_grad(): + residual = unet(image, t) - # predict previous mean of image x_t-1 - pred_prev_image = noise_scheduler.step(residual, image, t) + # predict previous mean of image x_t-1 + pred_prev_image = noise_scheduler.step(residual, image, t) - # optionally sample variance - variance = 0 - if t > 0: - noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = noise_scheduler.get_variance(t).sqrt() * noise + # optionally sample variance + variance = 0 + if t > 0: + noise = torch.randn(image.shape, generator=generator).to(image.device) + variance = noise_scheduler.get_variance(t).sqrt() * noise - # set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance + # set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance # 5. 
process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1) From 22ab275526d11139e739d4e8f19953b6626e541d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jun 2022 13:08:56 +0200 Subject: [PATCH 4/6] make transformes soft --- src/diffusers/pipelines/pipeline_glide.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 0bd96bbccf..b3fdc290f9 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -24,11 +24,15 @@ import torch.utils.checkpoint from torch import nn import tqdm -from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +try: + from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer + from transformers.activations import ACT2FN + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling + from transformers.modeling_utils import PreTrainedModel + from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +except: + print("Transformers is not installed") + pass from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from ..pipeline_utils import DiffusionPipeline From cee56cc7203c5cc1b228ee13841a918b93431713 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jun 2022 13:23:02 +0200 Subject: [PATCH 5/6] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2b4ce223f8..eba9e09410 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ The class provides functionality to compute previous image according to alpha, b ### Installation **Note**: If you want to run PyTorch on GPU on a CUDA-compatible machine, please make sure to install the corresponding `torch` version from the -[official website]( +[official website](https://pytorch.org/). ``` git clone https://github.com/huggingface/diffusers.git cd diffusers && pip install -e . From 97fcc4c6cc279da03ee9e49da63599c0bf328f60 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Jun 2022 13:27:05 +0200 Subject: [PATCH 6/6] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eba9e09410..6d3ea31100 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ num_prediction_steps = len(noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): # predict noise residual with torch.no_grad(): - residual = unet(image, t) + residual = unet(image, t) # predict previous mean of image x_t-1 pred_prev_image = noise_scheduler.step(residual, image, t) @@ -107,7 +107,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s variance = 0 if t > 0: noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = noise_scheduler.get_variance(t).sqrt() * noise + variance = noise_scheduler.get_variance(t).sqrt() * noise # set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance
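
For convenience, the DDPM quickstart that the README patches above converge on is reassembled below as a single, minimal sketch. The sampling loop, noise setup, and model loading lines are taken from the hunks in patches 1, 3 and 6; the scheduler construction, the `generator`/`torch_device` setup, and the final rescaling are not visible in those hunks, so `DDPMScheduler(timesteps=1000)`, `torch.manual_seed(0)`, the device selection, and the `* 127.5` rescaling are assumptions.

```python
import numpy as np
import PIL.Image
import torch
import tqdm

from diffusers import DDPMScheduler, UNetModel

# Assumed setup: not shown in the hunks above.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.manual_seed(0)

# 1. Load model and scheduler.
# The scheduler constructor arguments are an assumption; only len(noise_scheduler),
# step(...) and get_variance(...) are used below, as in the README.
noise_scheduler = DDPMScheduler(timesteps=1000)
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# 2. Sample gaussian noise
image = torch.randn(
    (1, unet.in_channels, unet.resolution, unet.resolution),
    generator=generator,
)
image = image.to(torch_device)

# 3. Denoise
num_prediction_steps = len(noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
    # predict noise residual
    with torch.no_grad():
        residual = unet(image, t)

    # predict previous mean of image x_t-1
    pred_prev_image = noise_scheduler.step(residual, image, t)

    # optionally sample variance
    variance = 0
    if t > 0:
        noise = torch.randn(image.shape, generator=generator).to(image.device)
        variance = noise_scheduler.get_variance(t).sqrt() * noise

    # set current image to prev_image: x_t -> x_t-1
    image = pred_prev_image + variance

# 5. process image to PIL (rescaling borrowed from the pipeline example in patch 1)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```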