From 418888a5665213c0921a68c98463be62754badb7 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 08:00:23 +0200 Subject: [PATCH 1/3] Pokemon DDPM training --- src/diffusers/trainers/training_ddpm.py | 39 +++++++++++++------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/diffusers/trainers/training_ddpm.py b/src/diffusers/trainers/training_ddpm.py index 6753a580e9..bc2a4d10ba 100644 --- a/src/diffusers/trainers/training_ddpm.py +++ b/src/diffusers/trainers/training_ddpm.py @@ -8,14 +8,14 @@ import PIL.Image from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel -from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor +from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor from tqdm.auto import tqdm from transformers import get_linear_schedule_with_warmup def set_seed(seed): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = False torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) @@ -30,13 +30,13 @@ model = UNetModel( attn_resolutions=(16,), ch=128, ch_mult=(1, 2, 2, 2), - dropout=0.1, + dropout=0.0, num_res_blocks=2, resamp_with_conv=True, resolution=32 ) noise_scheduler = DDPMScheduler(timesteps=1000) -optimizer = torch.optim.Adam(model.parameters(), lr=0.0002) +optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) num_epochs = 100 batch_size = 64 @@ -44,9 +44,10 @@ gradient_accumulation_steps = 2 augmentations = Compose( [ - Resize(32), - CenterCrop(32), RandomHorizontalFlip(), + RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1), + Resize(32, interpolation=InterpolationMode.BILINEAR), + CenterCrop(32), ToTensor(), Lambda(lambda x: x * 2 - 1), ] @@ -59,24 +60,24 @@ def transforms(examples): return {"input": images} -dataset = dataset.shuffle(seed=0) dataset.set_transform(transforms) -train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) +train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) -#lr_scheduler = get_linear_schedule_with_warmup( -# optimizer=optimizer, -# num_warmup_steps=1000, -# num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, -#) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=500, + num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, +) -model, optimizer, train_dataloader = accelerator.prepare( - model, optimizer, train_dataloader +model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler ) for epoch in range(num_epochs): model.train() pbar = tqdm(total=len(train_dataloader), unit="ba") pbar.set_description(f"Epoch {epoch}") + losses = [] for step, batch in enumerate(train_dataloader): clean_images = batch["input"] noisy_images = torch.empty_like(clean_images) @@ -101,10 +102,12 @@ for epoch in range(num_epochs): accelerator.backward(loss) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - # lr_scheduler.step() + lr_scheduler.step() optimizer.zero_grad() + loss = loss.detach().item() + losses.append(loss) pbar.update(1) - pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"]) optimizer.step() From bb3066428537da6263676448e737f315203d986c Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 11:33:24 +0200 Subject: [PATCH 2/3] Move the training example --- Makefile | 2 +- .../trainers => examples}/training_ddpm.py | 29 ++++++++++++------- tests/test_modeling_utils.py | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) rename {src/diffusers/trainers => examples}/training_ddpm.py (87%) diff --git a/Makefile b/Makefile index dad0611769..ddf143b6d4 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src -check_dirs := tests src utils +check_dirs := examples tests src utils modified_only_fixup: $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) diff --git a/src/diffusers/trainers/training_ddpm.py b/examples/training_ddpm.py similarity index 87% rename from src/diffusers/trainers/training_ddpm.py rename to examples/training_ddpm.py index bc2a4d10ba..b3ba111ccb 100644 --- a/src/diffusers/trainers/training_ddpm.py +++ b/examples/training_ddpm.py @@ -8,14 +8,23 @@ import PIL.Image from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel -from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor +from torchvision.transforms import ( + Compose, + InterpolationMode, + Lambda, + RandomCrop, + RandomHorizontalFlip, + RandomVerticalFlip, + Resize, + ToTensor, +) from tqdm.auto import tqdm from transformers import get_linear_schedule_with_warmup def set_seed(seed): - #torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) @@ -33,7 +42,7 @@ model = UNetModel( dropout=0.0, num_res_blocks=2, resamp_with_conv=True, - resolution=32 + resolution=32, ) noise_scheduler = DDPMScheduler(timesteps=1000) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) @@ -44,15 +53,15 @@ gradient_accumulation_steps = 2 augmentations = Compose( [ - RandomHorizontalFlip(), - RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1), Resize(32, interpolation=InterpolationMode.BILINEAR), - CenterCrop(32), + RandomHorizontalFlip(), + RandomVerticalFlip(), + RandomCrop(32), ToTensor(), Lambda(lambda x: x * 2 - 1), ] ) -dataset = load_dataset("huggan/pokemon", split="train") +dataset = load_dataset("huggan/flowers-102-categories", split="train") def transforms(examples): @@ -127,5 +136,5 @@ for epoch in range(num_epochs): image_pil = PIL.Image.fromarray(image_processed[0]) # save image - pipeline.save_pretrained("./poke-ddpm") - image_pil.save(f"./poke-ddpm/test_{epoch}.png") + pipeline.save_pretrained("./flowers-ddpm") + image_pil.save(f"./flowers-ddpm/test_{epoch}.png") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c119479fa..417ef353d6 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,7 +19,7 @@ import unittest import torch -from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler +from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline from diffusers.testing_utils import floats_tensor, slow, torch_device From d10441d877a84747d3f8e946b536107050e33f20 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 14 Jun 2022 11:43:05 +0200 Subject: [PATCH 3/3] Revert config eq --- src/diffusers/configuration_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 4436445334..61a80ff1e2 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -225,11 +225,11 @@ class ConfigMixin: text = reader.read() return json.loads(text) - # def __eq__(self, other): - # return self.__dict__ == other.__dict__ + def __eq__(self, other): + return self.__dict__ == other.__dict__ - # def __repr__(self): - # return f"{self.__class__.__name__} {self.to_json_string()}" + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" @property def config(self) -> Dict[str, Any]: