From 418888a5665213c0921a68c98463be62754badb7 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 08:00:23 +0200 Subject: [PATCH 1/9] Pokemon DDPM training --- src/diffusers/trainers/training_ddpm.py | 39 +++++++++++++------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/diffusers/trainers/training_ddpm.py b/src/diffusers/trainers/training_ddpm.py index 6753a580e9..bc2a4d10ba 100644 --- a/src/diffusers/trainers/training_ddpm.py +++ b/src/diffusers/trainers/training_ddpm.py @@ -8,14 +8,14 @@ import PIL.Image from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel -from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor +from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor from tqdm.auto import tqdm from transformers import get_linear_schedule_with_warmup def set_seed(seed): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = False torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) @@ -30,13 +30,13 @@ model = UNetModel( attn_resolutions=(16,), ch=128, ch_mult=(1, 2, 2, 2), - dropout=0.1, + dropout=0.0, num_res_blocks=2, resamp_with_conv=True, resolution=32 ) noise_scheduler = DDPMScheduler(timesteps=1000) -optimizer = torch.optim.Adam(model.parameters(), lr=0.0002) +optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) num_epochs = 100 batch_size = 64 @@ -44,9 +44,10 @@ gradient_accumulation_steps = 2 augmentations = Compose( [ - Resize(32), - CenterCrop(32), RandomHorizontalFlip(), + RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1), + Resize(32, interpolation=InterpolationMode.BILINEAR), + CenterCrop(32), ToTensor(), Lambda(lambda x: x * 2 - 1), ] @@ -59,24 +60,24 @@ def transforms(examples): return {"input": images} -dataset = dataset.shuffle(seed=0) dataset.set_transform(transforms) -train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) +train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) -#lr_scheduler = get_linear_schedule_with_warmup( -# optimizer=optimizer, -# num_warmup_steps=1000, -# num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, -#) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=500, + num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, +) -model, optimizer, train_dataloader = accelerator.prepare( - model, optimizer, train_dataloader +model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler ) for epoch in range(num_epochs): model.train() pbar = tqdm(total=len(train_dataloader), unit="ba") pbar.set_description(f"Epoch {epoch}") + losses = [] for step, batch in enumerate(train_dataloader): clean_images = batch["input"] noisy_images = torch.empty_like(clean_images) @@ -101,10 +102,12 @@ for epoch in range(num_epochs): accelerator.backward(loss) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - # lr_scheduler.step() + lr_scheduler.step() optimizer.zero_grad() + loss = loss.detach().item() + losses.append(loss) pbar.update(1) - pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + 
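# Side note on the hunk above (a sketch, not the patch's code): the script defines
# gradient_accumulation_steps = 2 and sizes the warmup schedule by it, yet steps the
# optimizer on every batch. The gated pattern the schedule assumes would look like
# the following, reusing the script's names; compute_loss is a hypothetical stand-in
# for the noise-prediction MSE computed in the loop above.
for step, batch in enumerate(train_dataloader):
    loss = compute_loss(model, batch)
    accelerator.backward(loss / gradient_accumulation_steps)
    if (step + 1) % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()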
pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"]) optimizer.step() From 7d8bf1a909565de0b577535d4e20ea678de2c693 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 08:51:00 +0000 Subject: [PATCH 2/9] make pndm easier --- _ | 156 +++++++++++++++++++++++ src/diffusers/pipelines/pipeline_pndm.py | 96 +++++--------- 2 files changed, 186 insertions(+), 66 deletions(-) create mode 100644 _ diff --git a/_ b/_ new file mode 100644 index 0000000000..702652c8aa --- /dev/null +++ b/_ @@ -0,0 +1,156 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +import torch + +import tqdm + +from ..pipeline_utils import DiffusionPipeline + + +class PNDM(DiffusionPipeline): + def __init__(self, unet, noise_scheduler): + super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") + self.register_modules(unet=unet, noise_scheduler=noise_scheduler) + + def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): + # eta corresponds to η in paper and should be between [0, 1] + if torch_device is None: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + num_trained_timesteps = self.noise_scheduler.timesteps + inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) + + self.unet.to(torch_device) + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), + generator=generator, + ) + image = image.to(torch_device) + + seq = list(inference_step_times) + seq_next = [-1] + list(seq[:-1]) + model = self.unet + + warmup_steps = [len(seq) - (i // 4 + 1) for i in range(3 * 4)] + + ets = [] + prev_image = image + for i, step_idx in enumerate(warmup_steps): + i = seq[step_idx] + j = seq_next[step_idx] + + t = (torch.ones(image.shape[0]) * i) + t_next = (torch.ones(image.shape[0]) * j) + + residual = model(image.to("cuda"), t.to("cuda")) + residual = residual.to("cpu") + + image = image.to("cpu") + image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_list[0], t_list[1], residual) + + if i % 4 == 0: + ets.append(residual) + prev_image = image + + for + + ets = [] + step_idx = len(seq) - 1 + while step_idx >= 0: + i = seq[step_idx] + j = seq_next[step_idx] + + t = (torch.ones(image.shape[0]) * i) + t_next = (torch.ones(image.shape[0]) * j) + + residual = model(image.to("cuda"), t.to("cuda")) + residual = residual.to("cpu") + + t_list = [t, (t+t_next)/2, t_next] + + ets.append(residual) + if len(ets) <= 3: + image = image.to("cpu") + x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual) + + e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu") + x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2) + e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu") + x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3) + e_4 = model(x_4.to("cuda"), 
t_list[2].to("cuda")).to("cpu") + residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4) + else: + residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) + + img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) + image = img_next + + step_idx = step_idx - 1 + +# if len(prev_noises) in [1, 2]: +# t = (t + t_next) / 2 +# elif len(prev_noises) == 3: +# t = t_next / 2 + +# if len(prev_noises) == 0: +# ets.append(residual) +# +# if len(ets) > 3: +# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) +# step_idx = step_idx - 1 +# elif len(ets) <= 3 and len(prev_noises) == 3: +# residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual) +# prev_noises = [] +# step_idx = step_idx - 1 +# elif len(ets) <= 3 and len(prev_noises) < 3: +# prev_noises.append(residual) +# if len(prev_noises) < 2: +# t_next = (t + t_next) / 2 +# +# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) + + return image + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_image -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_image_direction -> "direction pointingc to x_t" + # - pred_prev_image -> "x_t-1" +# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + # 1. predict noise residual +# with torch.no_grad(): +# residual = self.unet(image, inference_step_times[t]) +# + # 2. predict previous mean of image x_t-1 +# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) +# + # 3. optionally sample variance +# variance = 0 +# if eta > 0: +# noise = torch.randn(image.shape, generator=generator).to(image.device) +# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise +# + # 4. 
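# The loop above combines model residuals in two classical ways (cf. the PNDM paper,
# https://arxiv.org/pdf/2202.09778.pdf). A minimal sketch of both weightings, with
# illustrative names only:
def runge_kutta_combine(e_1, e_2, e_3, e_4):
    # classical 4th-order Runge-Kutta weights (1/6, 1/3, 1/3, 1/6), used for warmup
    return (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)

def adams_bashforth_combine(ets):
    # 4th-order Adams-Bashforth linear multistep, used once four past residuals exist
    return (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])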
set current image to prev_image: x_t -> x_t-1 +# image = pred_prev_image + variance diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index 21eb39f7c1..d0fe01ec8a 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -48,9 +48,36 @@ class PNDM(DiffusionPipeline): seq_next = [-1] + list(seq[:-1]) model = self.unet + warmup_time_steps = list(reversed([(t + 5) // 10 * 10 for t in range(seq[-4], seq[-1], 5)])) + + cur_residual = 0 + prev_image = image ets = [] - prev_noises = [] - step_idx = len(seq) - 1 + for i in range(len(warmup_time_steps)): + t = warmup_time_steps[i] * torch.ones(image.shape[0]) + t_next = (warmup_time_steps[i + 1] if i < len(warmup_time_steps) - 1 else warmup_time_steps[-1]) * torch.ones(image.shape[0]) + + residual = model(image.to("cuda"), t.to("cuda")) + residual = residual.to("cpu") + + if i % 4 == 0: + cur_residual += 1 / 6 * residual + ets.append(residual) + prev_image = image + elif (i - 1) % 4 == 0: + cur_residual += 1 / 3 * residual + elif (i - 2) % 4 == 0: + cur_residual += 1 / 3 * residual + elif (i - 3) % 4 == 0: + cur_residual += 1 / 6 * residual + residual = cur_residual + cur_residual = 0 + + image = image.to("cpu") + t_2 = warmup_time_steps[4 * (i // 4)] * torch.ones(image.shape[0]) + image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_2, t_next, residual) + + step_idx = len(seq) - 4 while step_idx >= 0: i = seq[step_idx] j = seq_next[step_idx] @@ -60,75 +87,12 @@ class PNDM(DiffusionPipeline): residual = model(image.to("cuda"), t.to("cuda")) residual = residual.to("cpu") - - t_list = [t, (t+t_next)/2, t_next] - ets.append(residual) - if len(ets) <= 3: - image = image.to("cpu") - x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual) - - e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu") - x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2) - e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu") - x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3) - e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu") - residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4) - else: - residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) + residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) image = img_next step_idx = step_idx - 1 -# if len(prev_noises) in [1, 2]: -# t = (t + t_next) / 2 -# elif len(prev_noises) == 3: -# t = t_next / 2 - -# if len(prev_noises) == 0: -# ets.append(residual) -# -# if len(ets) > 3: -# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) -# step_idx = step_idx - 1 -# elif len(ets) <= 3 and len(prev_noises) == 3: -# residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual) -# prev_noises = [] -# step_idx = step_idx - 1 -# elif len(ets) <= 3 and len(prev_noises) < 3: -# prev_noises.append(residual) -# if len(prev_noises) < 2: -# t_next = (t + t_next) / 2 -# -# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) - return image - - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_image -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - 
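# A sketch of the DDIM update that the surrounding notation comments describe
# (formula (12) of https://arxiv.org/pdf/2010.02502.pdf). alpha_bar_* stand for
# cumulative alpha products as torch tensors; the names mirror the comments and are
# illustrative, not the scheduler's actual API:
def ddim_step(image, pred_noise_t, alpha_bar_t, alpha_bar_prev, std_dev_t):
    pred_original_image = (image - (1 - alpha_bar_t).sqrt() * pred_noise_t) / alpha_bar_t.sqrt()
    pred_image_direction = (1 - alpha_bar_prev - std_dev_t**2).sqrt() * pred_noise_t
    pred_prev_image = alpha_bar_prev.sqrt() * pred_original_image + pred_image_direction
    return pred_prev_image  # plus std_dev_t * noise when eta > 0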
pred_image_direction -> "direction pointingc to x_t" - # - pred_prev_image -> "x_t-1" -# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - # 1. predict noise residual -# with torch.no_grad(): -# residual = self.unet(image, inference_step_times[t]) -# - # 2. predict previous mean of image x_t-1 -# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) -# - # 3. optionally sample variance -# variance = 0 -# if eta > 0: -# noise = torch.randn(image.shape, generator=generator).to(image.device) -# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise -# - # 4. set current image to prev_image: x_t -> x_t-1 -# image = pred_prev_image + variance From bb3066428537da6263676448e737f315203d986c Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 11:33:24 +0200 Subject: [PATCH 3/9] Move the training example --- Makefile | 2 +- .../trainers => examples}/training_ddpm.py | 29 ++++++++++++------- tests/test_modeling_utils.py | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) rename {src/diffusers/trainers => examples}/training_ddpm.py (87%) diff --git a/Makefile b/Makefile index dad0611769..ddf143b6d4 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src -check_dirs := tests src utils +check_dirs := examples tests src utils modified_only_fixup: $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) diff --git a/src/diffusers/trainers/training_ddpm.py b/examples/training_ddpm.py similarity index 87% rename from src/diffusers/trainers/training_ddpm.py rename to examples/training_ddpm.py index bc2a4d10ba..b3ba111ccb 100644 --- a/src/diffusers/trainers/training_ddpm.py +++ b/examples/training_ddpm.py @@ -8,14 +8,23 @@ import PIL.Image from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel -from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor +from torchvision.transforms import ( + Compose, + InterpolationMode, + Lambda, + RandomCrop, + RandomHorizontalFlip, + RandomVerticalFlip, + Resize, + ToTensor, +) from tqdm.auto import tqdm from transformers import get_linear_schedule_with_warmup def set_seed(seed): - #torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) @@ -33,7 +42,7 @@ model = UNetModel( dropout=0.0, num_res_blocks=2, resamp_with_conv=True, - resolution=32 + resolution=32, ) noise_scheduler = DDPMScheduler(timesteps=1000) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) @@ -44,15 +53,15 @@ gradient_accumulation_steps = 2 augmentations = Compose( [ - RandomHorizontalFlip(), - RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1), Resize(32, interpolation=InterpolationMode.BILINEAR), - CenterCrop(32), + RandomHorizontalFlip(), + RandomVerticalFlip(), + RandomCrop(32), ToTensor(), Lambda(lambda x: x * 2 - 1), ] ) -dataset = load_dataset("huggan/pokemon", split="train") +dataset = load_dataset("huggan/flowers-102-categories", split="train") def transforms(examples): @@ -127,5 +136,5 @@ for epoch in range(num_epochs): image_pil = 
PIL.Image.fromarray(image_processed[0]) # save image - pipeline.save_pretrained("./poke-ddpm") - image_pil.save(f"./poke-ddpm/test_{epoch}.png") + pipeline.save_pretrained("./flowers-ddpm") + image_pil.save(f"./flowers-ddpm/test_{epoch}.png") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c119479fa..417ef353d6 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,7 +19,7 @@ import unittest import torch -from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler +from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline from diffusers.testing_utils import floats_tensor, slow, torch_device From d10441d877a84747d3f8e946b536107050e33f20 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 11:43:05 +0200 Subject: [PATCH 4/9] Revert config eq --- src/diffusers/configuration_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 4436445334..61a80ff1e2 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -225,11 +225,11 @@ class ConfigMixin: text = reader.read() return json.loads(text) - # def __eq__(self, other): - # return self.__dict__ == other.__dict__ + def __eq__(self, other): + return self.__dict__ == other.__dict__ - # def __repr__(self): - # return f"{self.__class__.__name__} {self.to_json_string()}" + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" @property def config(self) -> Dict[str, Any]: From 559b8cbf469f70e12bffb1607bc189ef5e14651c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:17:45 +0000 Subject: [PATCH 5/9] finish pndm --- src/diffusers/pipelines/pipeline_pndm.py | 58 +++----------- src/diffusers/schedulers/scheduling_pndm.py | 86 +++++++++++++-------- 2 files changed, 64 insertions(+), 80 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index d0fe01ec8a..1116b6042a 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline): if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" - num_trained_timesteps = self.noise_scheduler.timesteps - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - self.unet.to(torch_device) # Sample gaussian noise to begin loop @@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline): ) image = image.to(torch_device) - seq = list(inference_step_times) - seq_next = [-1] + list(seq[:-1]) - model = self.unet - - warmup_time_steps = list(reversed([(t + 5) // 10 * 10 for t in range(seq[-4], seq[-1], 5)])) - - cur_residual = 0 + warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps) prev_image = image - ets = [] - for i in range(len(warmup_time_steps)): - t = warmup_time_steps[i] * torch.ones(image.shape[0]) - t_next = (warmup_time_steps[i + 1] if i < len(warmup_time_steps) - 1 else warmup_time_steps[-1]) * torch.ones(image.shape[0]) + for t in tqdm.tqdm(range(len(warmup_time_steps))): + t_orig = warmup_time_steps[t] + residual = self.unet(image, t_orig) - residual = model(image.to("cuda"), t.to("cuda")) - residual 
= residual.to("cpu") - - if i % 4 == 0: - cur_residual += 1 / 6 * residual - ets.append(residual) + if t % 4 == 0: prev_image = image - elif (i - 1) % 4 == 0: - cur_residual += 1 / 3 * residual - elif (i - 2) % 4 == 0: - cur_residual += 1 / 3 * residual - elif (i - 3) % 4 == 0: - cur_residual += 1 / 6 * residual - residual = cur_residual - cur_residual = 0 - image = image.to("cpu") - t_2 = warmup_time_steps[4 * (i // 4)] * torch.ones(image.shape[0]) - image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_2, t_next, residual) + image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps) - step_idx = len(seq) - 4 - while step_idx >= 0: - i = seq[step_idx] - j = seq_next[step_idx] + timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) + for t in tqdm.tqdm(range(len(timesteps))): + t_orig = timesteps[t] + residual = self.unet(image, t_orig) - t = (torch.ones(image.shape[0]) * i) - t_next = (torch.ones(image.shape[0]) * j) - - residual = model(image.to("cuda"), t.to("cuda")) - residual = residual.to("cpu") - ets.append(residual) - residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) - - img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) - image = img_next - - step_idx = step_idx - 1 + image = self.noise_scheduler.step(residual, image, t, num_inference_steps) return image diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d69e4a5188..ad639a90f1 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) + # hardcode for now + self.pndm_order = 4 + self.cur_residual = 0 + + # running values + self.ets = [] + self.warmup_time_steps = {} + self.time_steps = {} + # self.register_buffer("betas", betas.to(torch.float32)) # self.register_buffer("alphas", alphas.to(torch.float32)) # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) @@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.one return self.alphas_cumprod[time_step] - def step(self, img, t_start, t_end, model, ets): -# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets) -#def gen_order_4(img, t, t_next, model, alphas_cump, ets): - t_next, t = t_start, t_end + def get_warmup_time_steps(self, num_inference_steps): + if num_inference_steps in self.warmup_time_steps: + return self.warmup_time_steps[num_inference_steps] - noise_ = model(img.to("cuda"), t.to("cuda")) - noise_ = noise_.to("cpu") + inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) - t_list = [t, (t+t_next)/2, t_next] - if len(ets) > 2: - ets.append(noise_) - noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) - else: - noise = self.runge_kutta(img, t_list, model, ets, noise_) + warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order) + self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1])) - img_next = self.transfer(img.to("cpu"), t, t_next, noise) - return img_next, ets + return self.warmup_time_steps[num_inference_steps] - def runge_kutta(self, x, t_list, model, ets, noise_): - model = model.to("cuda") - x = x.to("cpu") + def get_time_steps(self, 
num_inference_steps): + if num_inference_steps in self.time_steps: + return self.time_steps[num_inference_steps] - e_1 = noise_ - ets.append(e_1) - x_2 = self.transfer(x, t_list[0], t_list[1], e_1) + inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) + self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3])) - e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")) - e_2 = e_2.to("cpu") - x_3 = self.transfer(x, t_list[0], t_list[1], e_2) + return self.time_steps[num_inference_steps] - e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")) - e_3 = e_3.to("cpu") - x_4 = self.transfer(x, t_list[0], t_list[2], e_3) + def step_warm_up(self, residual, image, t, num_inference_steps): + warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) - e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")) - e_4 = e_4.to("cpu") + t_prev = warmup_time_steps[t // 4 * 4] + t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)] - et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4) + if t % 4 == 0: + self.cur_residual += 1 / 6 * residual + self.ets.append(residual) + elif (t - 1) % 4 == 0: + self.cur_residual += 1 / 3 * residual + elif (t - 2) % 4 == 0: + self.cur_residual += 1 / 3 * residual + elif (t - 3) % 4 == 0: + residual = self.cur_residual + 1 / 6 * residual + self.cur_residual = 0 - return et + return self.transfer(image, t_prev, t_next, residual) + + def step(self, residual, image, t, num_inference_steps): + timesteps = self.get_time_steps(num_inference_steps) + + t_prev = timesteps[t] + t_next = timesteps[min(t + 1, len(timesteps) - 1)] + self.ets.append(residual) + + residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + return self.transfer(image, t_prev, t_next, residual) def transfer(self, x, t, t_next, et): - alphas_cump = self.alphas_cumprod - at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1) - at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1) + # TODO(Patrick): clean up to be compatible with numpy and give better names + + alphas_cump = self.alphas_cumprod.to(x.device) + at = alphas_cump[t + 1].view(-1, 1, 1, 1) + at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1) x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) From f0a99e76846537525cf509d7e2b16b3b91f46de7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:22:53 +0000 Subject: [PATCH 6/9] finish --- src/diffusers/schedulers/scheduling_pndm.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index ad639a90f1..fa1c9ca56d 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -55,32 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) - # hardcode for now + # for now we only support F-PNDM, i.e. 
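# The transfer() method introduced above is the PNDM transfer rule. Restated as a
# standalone sketch (a = alpha_bar at t, a_next = alpha_bar at t_next, torch tensors;
# not the scheduler's exact API). Algebraically the coefficient of x reduces to
# sqrt(a_next) / sqrt(a), i.e. the familiar DDIM-style rescaling of the sample:
def transfer(x, a, a_next, et):
    x_delta = (a_next - a) * (
        x / (a.sqrt() * (a.sqrt() + a_next.sqrt()))
        - et / (a.sqrt() * (((1 - a_next) * a).sqrt() + ((1 - a) * a_next).sqrt()))
    )
    return x + x_delta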
the runge-kutta method self.pndm_order = 4 - self.cur_residual = 0 # running values + self.cur_residual = 0 self.ets = [] self.warmup_time_steps = {} self.time_steps = {} - # self.register_buffer("betas", betas.to(torch.float32)) - # self.register_buffer("alphas", alphas.to(torch.float32)) - # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) - - # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - # TODO(PVP) - check how much of these is actually necessary! - # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ... - # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 - # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - # if variance_type == "fixed_small": - # log_variance = torch.log(variance.clamp(min=1e-20)) - # elif variance_type == "fixed_large": - # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) - # - # - # self.register_buffer("log_variance", log_variance.to(torch.float32)) - def get_alpha(self, time_step): return self.alphas[time_step] From 57243fd56537dc7215360ab5b6db882b2a7b7fe5 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 12:37:28 +0200 Subject: [PATCH 7/9] GLIDE integration test --- src/diffusers/configuration_utils.py | 3 --- src/diffusers/pipelines/pipeline_glide.py | 16 ++++++++-------- tests/test_modeling_utils.py | 17 ++++++++++++++++- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 61a80ff1e2..5ba5ddec28 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -225,9 +225,6 @@ class ConfigMixin: text = reader.read() return json.loads(text) - def __eq__(self, other): - return self.__dict__ == other.__dict__ - def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 1f6495e890..138ce9d2f2 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline): # 1. Sample gaussian noise batch_size = 2 # second image is empty for classifier-free guidance - image = self.text_noise_scheduler.sample_noise( - (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator - ) + image = torch.randn( + (batch_size, self.text_unet.in_channels, 64, 64), generator=generator + ).to(torch_device) # 2. 
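# Background for the two-row batch built here ([prompt, ""]): classifier-free
# guidance runs the model on a conditional and an unconditional copy and extrapolates
# between the two noise predictions. A hedged sketch only -- guidance_scale and the
# way the batch is split are assumptions for exposition, not GLIDE's exact internals:
def guided_residual(model_out, guidance_scale=3.0):
    eps_cond, eps_uncond = model_out[:1], model_out[1:]
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)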
Encode tokens - # an empty input is needed to guide the model away from ( + # an empty input is needed to guide the model away from it inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt") input_ids = inputs["input_ids"].to(torch_device) attention_mask = inputs["attention_mask"].to(torch_device) @@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline): mean, variance, log_variance, pred_xstart = self.p_mean_variance( text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out ) - noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator) + noise = torch.randn(image.shape, generator=generator).to(torch_device) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise @@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline): self.upscale_unet.resolution, ), generator=generator, - ) - image = image.to(torch_device) * upsample_temp + ).to(torch_device) + image = image * upsample_temp num_trained_timesteps = self.upscale_noise_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) @@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline): # 3. optionally sample variance variance = 0 if eta > 0: - noise = torch.randn(image.shape, generator=generator).to(image.device) + noise = torch.randn(image.shape, generator=generator).to(torch_device) variance = ( self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 417ef353d6..6db8831626 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,7 +19,7 @@ import unittest import torch -from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel +from diffusers import DDIM, DDPM, PNDM, GLIDE, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -212,3 +212,18 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_glide_text2img(self): + model_id = "fusing/glide-base" + glide = GLIDE.from_pretrained(model_id) + + prompt = "a pencil sketch of a corgi" + generator = torch.manual_seed(0) + image = glide(prompt, generator=generator, num_inference_steps_upscale=20) + + image_slice = image[0, :3, :3, -1].cpu() + + assert image.shape == (1, 256, 256, 3) + expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 From df64f624c044e18071f178787b67e50f47c57028 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:39:21 +0000 Subject: [PATCH 8/9] finish pndm --- src/diffusers/schedulers/scheduling_pndm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index fa1c9ca56d..cc27b52055 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ 
b/src/diffusers/schedulers/scheduling_pndm.py @@ -96,6 +96,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.time_steps[num_inference_steps] def step_warm_up(self, residual, image, t, num_inference_steps): + # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) t_prev = warmup_time_steps[t // 4 * 4] From da1f920ef124d00c5e81ba423e9d45e8783e9841 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:50:05 +0000 Subject: [PATCH 9/9] finalize pndm --- src/diffusers/pipelines/pipeline_pndm.py | 11 ++++------- src/diffusers/schedulers/scheduling_pndm.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index 1116b6042a..93d735a8a8 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline): self.register_modules(unet=unet, noise_scheduler=noise_scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): - # eta corresponds to η in paper and should be between [0, 1] + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline): image = image.to(torch_device) warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps) - prev_image = image for t in tqdm.tqdm(range(len(warmup_time_steps))): t_orig = warmup_time_steps[t] residual = self.unet(image, t_orig) - if t % 4 == 0: - prev_image = image - - image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps) + image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) for t in tqdm.tqdm(range(len(timesteps))): t_orig = timesteps[t] residual = self.unet(image, t_orig) - image = self.noise_scheduler.step(residual, image, t, num_inference_steps) + image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps) return image diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index cc27b52055..85fa6fb2f5 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) - # for now we only support F-PNDM, i.e. the runge-kutta method + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at equations (12) and (13) and the Algorithm 2. 
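# Worked example of get_warmup_time_steps() above (assuming timesteps=1000 and
# num_inference_steps=50, i.e. a stride of 20); a runnable sketch for exposition:
import numpy as np

timesteps, num_inference_steps, pndm_order = 1000, 50, 4
stride = timesteps // num_inference_steps
inference_step_times = list(range(0, timesteps, stride))
warmup = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, stride // 2]), pndm_order
)
warmup = list(reversed(warmup[:-1].repeat(2)[1:-1]))
# -> [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]:
# three Runge-Kutta warmup steps of four model evaluations each (t, t-10, t-10, t-20).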
self.pndm_order = 4 # running values self.cur_residual = 0 + self.cur_image = None self.ets = [] self.warmup_time_steps = {} self.time_steps = {} @@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.time_steps[num_inference_steps] - def step_warm_up(self, residual, image, t, num_inference_steps): + def step_prk(self, residual, image, t, num_inference_steps): # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) @@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): if t % 4 == 0: self.cur_residual += 1 / 6 * residual self.ets.append(residual) + self.cur_image = image elif (t - 1) % 4 == 0: self.cur_residual += 1 / 3 * residual elif (t - 2) % 4 == 0: @@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): residual = self.cur_residual + 1 / 6 * residual self.cur_residual = 0 - return self.transfer(image, t_prev, t_next, residual) + return self.transfer(self.cur_image, t_prev, t_next, residual) - def step(self, residual, image, t, num_inference_steps): + def step_plms(self, residual, image, t, num_inference_steps): timesteps = self.get_time_steps(num_inference_steps) t_prev = timesteps[t]
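With the series applied, the resulting pipeline can be exercised end to end. A minimal
usage sketch (the checkpoint id is a placeholder, not a real repository; the call
signature matches the __call__ defined in pipeline_pndm.py):

    import torch
    from diffusers import PNDM

    pipeline = PNDM.from_pretrained("fusing/pndm-example")  # placeholder model id
    generator = torch.manual_seed(0)
    image = pipeline(generator=generator, num_inference_steps=50)
    # `image` is a tensor of shape (batch_size, C, H, W), roughly in [-1, 1]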