From df64f624c044e18071f178787b67e50f47c57028 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:39:21 +0000 Subject: [PATCH 1/5] finish pndm --- src/diffusers/schedulers/scheduling_pndm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index fa1c9ca56d..cc27b52055 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -96,6 +96,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.time_steps[num_inference_steps] def step_warm_up(self, residual, image, t, num_inference_steps): + # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) t_prev = warmup_time_steps[t // 4 * 4] From da1f920ef124d00c5e81ba423e9d45e8783e9841 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Jun 2022 10:50:05 +0000 Subject: [PATCH 2/5] finalize pndm --- src/diffusers/pipelines/pipeline_pndm.py | 11 ++++------- src/diffusers/schedulers/scheduling_pndm.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index 1116b6042a..93d735a8a8 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline): self.register_modules(unet=unet, noise_scheduler=noise_scheduler) def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): - # eta corresponds to η in paper and should be between [0, 1] + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://arxiv.org/pdf/2202.09778.pdf if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline): image 
= image.to(torch_device) warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps) - prev_image = image for t in tqdm.tqdm(range(len(warmup_time_steps))): t_orig = warmup_time_steps[t] residual = self.unet(image, t_orig) - if t % 4 == 0: - prev_image = image - - image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps) + image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) for t in tqdm.tqdm(range(len(timesteps))): t_orig = timesteps[t] residual = self.unet(image, t_orig) - image = self.noise_scheduler.step(residual, image, t, num_inference_steps) + image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps) return image diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index cc27b52055..85fa6fb2f5 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) - # for now we only support F-PNDM, i.e. the runge-kutta method + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at equations (12) and (13) and the Algorithm 2. 
self.pndm_order = 4 # running values self.cur_residual = 0 + self.cur_image = None self.ets = [] self.warmup_time_steps = {} self.time_steps = {} @@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.time_steps[num_inference_steps] - def step_warm_up(self, residual, image, t, num_inference_steps): + def step_prk(self, residual, image, t, num_inference_steps): # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) @@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): if t % 4 == 0: self.cur_residual += 1 / 6 * residual self.ets.append(residual) + self.cur_image = image elif (t - 1) % 4 == 0: self.cur_residual += 1 / 3 * residual elif (t - 2) % 4 == 0: @@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): residual = self.cur_residual + 1 / 6 * residual self.cur_residual = 0 - return self.transfer(image, t_prev, t_next, residual) + return self.transfer(self.cur_image, t_prev, t_next, residual) - def step(self, residual, image, t, num_inference_steps): + def step_plms(self, residual, image, t, num_inference_steps): timesteps = self.get_time_steps(num_inference_steps) t_prev = timesteps[t] From d81b56ba5ce2ee13cd18131a4367080d78304e4d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 14 Jun 2022 12:50:27 +0200 Subject: [PATCH 3/5] allow loading model from pipeline module --- src/diffusers/pipeline_utils.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 743a807510..77be534009 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin): config_name = "model_index.json" def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + for name, module in 
kwargs.items(): + # check if the module is a pipeline module + is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1]) + # retrive library library = module.__module__.split(".")[0] - # if library is not in LOADABLE_CLASSES, then it is a custom module - if library not in LOADABLE_CLASSES: + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: library = module.__module__.split(".")[-1] # retrive class_name @@ -151,12 +160,22 @@ class DiffusionPipeline(ConfigMixin): init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} - + + # import it here to avoid circular import + from diffusers import pipelines + # 4. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): - # if the model is not in diffusers or transformers, we need to load it from the hub - # assumes that it's a subclass of ModelMixin - if library_name == module_candidate_name: + is_pipeline_module = hasattr(pipelines, library_name) + # if the model is in a pipeline module, then we load it from the pipeline + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} + elif library_name == module_candidate_name: + # if the model is not in diffusers or transformers, we need to load it from the hub + # assumes that it's a subclass of ModelMixin class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder) # since it's not from a library, we need to check class candidates for all importable classes importable_classes = ALL_IMPORTABLE_CLASSES From 147d8e07029700e49a66991e6263fd2c39fd5fec Mon Sep 17 00:00:00 2001 
From: patil-suraj Date: Tue, 14 Jun 2022 12:50:40 +0200 Subject: [PATCH 4/5] add test for loading model from pipeline module --- tests/test_modeling_utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c119479fa..cacc356530 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,9 +19,10 @@ import unittest import torch -from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler +from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.pipeline_bddm import DiffWave from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + def test_module_from_pipeline(self): + model = DiffWave(num_res_layers=4) + noise_scheduler = DDPMScheduler(timesteps=12) + + bddm = BDDM(model, noise_scheduler) + + # check if the library name for the diffwave module is set to pipeline module + self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") + + # check if we can save and load the pipeline + with tempfile.TemporaryDirectory() as tmpdirname: + bddm.save_pretrained(tmpdirname) + _ = BDDM.from_pretrained(tmpdirname) + # check if the same works using the DiffusionPipeline class + _ = DiffusionPipeline.from_pretrained(tmpdirname) \ No newline at end of file From be736cb24856b0b487096681e45995417cc9d8fd Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 14 Jun 2022 15:31:36 +0200 Subject: [PATCH 5/5] delete unused files 
--- _ | 156 ---------------------------------------------------- run_pndm.py | 27 --------- 2 files changed, 183 deletions(-) delete mode 100644 _ delete mode 100755 run_pndm.py diff --git a/_ b/_ deleted file mode 100644 index 702652c8aa..0000000000 --- a/_ +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - - -import torch - -import tqdm - -from ..pipeline_utils import DiffusionPipeline - - -class PNDM(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): - super().__init__() - noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, noise_scheduler=noise_scheduler) - - def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): - # eta corresponds to η in paper and should be between [0, 1] - if torch_device is None: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" - - num_trained_timesteps = self.noise_scheduler.timesteps - inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - - self.unet.to(torch_device) - - # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), - generator=generator, - ) - image = image.to(torch_device) - - seq = list(inference_step_times) - seq_next = [-1] + list(seq[:-1]) - model = self.unet - - warmup_steps = [len(seq) - (i // 4 + 
1) for i in range(3 * 4)] - - ets = [] - prev_image = image - for i, step_idx in enumerate(warmup_steps): - i = seq[step_idx] - j = seq_next[step_idx] - - t = (torch.ones(image.shape[0]) * i) - t_next = (torch.ones(image.shape[0]) * j) - - residual = model(image.to("cuda"), t.to("cuda")) - residual = residual.to("cpu") - - image = image.to("cpu") - image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_list[0], t_list[1], residual) - - if i % 4 == 0: - ets.append(residual) - prev_image = image - - for - - ets = [] - step_idx = len(seq) - 1 - while step_idx >= 0: - i = seq[step_idx] - j = seq_next[step_idx] - - t = (torch.ones(image.shape[0]) * i) - t_next = (torch.ones(image.shape[0]) * j) - - residual = model(image.to("cuda"), t.to("cuda")) - residual = residual.to("cpu") - - t_list = [t, (t+t_next)/2, t_next] - - ets.append(residual) - if len(ets) <= 3: - image = image.to("cpu") - x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual) - - e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu") - x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2) - e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu") - x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3) - e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu") - residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4) - else: - residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) - - img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) - image = img_next - - step_idx = step_idx - 1 - -# if len(prev_noises) in [1, 2]: -# t = (t + t_next) / 2 -# elif len(prev_noises) == 3: -# t = t_next / 2 - -# if len(prev_noises) == 0: -# ets.append(residual) -# -# if len(ets) > 3: -# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) -# step_idx = step_idx - 1 -# elif len(ets) <= 3 and len(prev_noises) == 3: -# residual = (1 / 6) * (prev_noises[-3] + 2 * 
prev_noises[-2] + 2 * prev_noises[-1] + residual) -# prev_noises = [] -# step_idx = step_idx - 1 -# elif len(ets) <= 3 and len(prev_noises) < 3: -# prev_noises.append(residual) -# if len(prev_noises) < 2: -# t_next = (t + t_next) / 2 -# -# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual) - - return image - - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_image -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_image_direction -> "direction pointingc to x_t" - # - pred_prev_image -> "x_t-1" -# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - # 1. predict noise residual -# with torch.no_grad(): -# residual = self.unet(image, inference_step_times[t]) -# - # 2. predict previous mean of image x_t-1 -# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) -# - # 3. optionally sample variance -# variance = 0 -# if eta > 0: -# noise = torch.randn(image.shape, generator=generator).to(image.device) -# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise -# - # 4. 
set current image to prev_image: x_t -> x_t-1 -# image = pred_prev_image + variance diff --git a/run_pndm.py b/run_pndm.py deleted file mode 100755 index 6ef17bff33..0000000000 --- a/run_pndm.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 -from diffusers import PNDM, UNetModel, PNDMScheduler -import PIL.Image -import numpy as np -import torch - -model_id = "fusing/ddim-celeba-hq" - -model = UNetModel.from_pretrained(model_id) -scheduler = PNDMScheduler() - -# load model and scheduler -ddpm = PNDM(unet=model, noise_scheduler=scheduler) - -# run pipeline in inference (sample random noise and denoise) -image = ddpm() - -# process image to PIL -image_processed = image.cpu().permute(0, 2, 3, 1) -image_processed = (image_processed + 1.0) / 2 -image_processed = torch.clamp(image_processed, 0.0, 1.0) -image_processed = image_processed * 255 -image_processed = image_processed.numpy().astype(np.uint8) -image_pil = PIL.Image.fromarray(image_processed[0]) - -# save image -image_pil.save("/home/patrick/images/test.png")