diff --git a/models/vision/ddim/README.md b/models/vision/ddim/README.md
new file mode 100644
index 0000000000..1070c04230
--- /dev/null
+++ b/models/vision/ddim/README.md
@@ -0,0 +1,28 @@
+
+
+# Denoising Diffusion Implicit Models (DDIM)
+
+## Overview
+
+DDIM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*.
+
+The abstract from the paper is the following:
+
+*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
+
+Tips:
+
+- ...
+- ...
+
+This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
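For context on the abstract's claim that the reverse process can be sampled much faster and that computation can be traded off against sample quality: the update implemented by modeling_ddim.py below appears to correspond to Eq. (12) of the paper. A sketch in LaTeX, writing $\bar\alpha_t$ for the cumulative products returned by `compute_alpha` and $\epsilon_\theta$ for the U-Net's noise prediction:

$$
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t \epsilon_t,
\qquad
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
\sigma_t = \eta \sqrt{\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}} \sqrt{1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}}.
$$

With $\eta = 0$ the step is deterministic (pure DDIM); $\eta = 1$ recovers a DDPM-like stochastic step. In the code, `at` and `at_next` play the roles of $\bar\alpha_t$ and $\bar\alpha_{t-1}$, and `c1`, `c2` correspond to $\sigma_t$ and $\sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}$.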
diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py
new file mode 100644
index 0000000000..dcd084c034
--- /dev/null
+++ b/models/vision/ddim/modeling_ddim.py
@@ -0,0 +1,67 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from diffusers import DiffusionPipeline
+import tqdm
+import torch
+
+
+def compute_alpha(beta, t):
+    # Prepend a zero beta so that index 0 corresponds to "no noise", then gather the
+    # cumulative product of (1 - beta) at the requested timesteps.
+    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
+    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
+    return a
+
+
+class DDIM(DiffusionPipeline):
+
+    def __init__(self, unet, noise_scheduler):
+        super().__init__()
+        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+
+    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
+        # Sub-sample `inference_time_steps` timesteps out of the full training schedule.
+        seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps)
+        b = self.noise_scheduler.betas
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        # Start from pure Gaussian noise.
+        x = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
+
+        with torch.no_grad():
+            n = batch_size
+            seq_next = [-1] + list(seq[:-1])
+            x0_preds = []
+            xs = [x]
+            # Iterate from t = T down to t = 0 over the sub-sampled timestep pairs.
+            for i, j in tqdm.tqdm(zip(reversed(seq), reversed(seq_next))):
+                t = (torch.ones(n) * i).to(x.device)
+                next_t = (torch.ones(n) * j).to(x.device)
+                at = compute_alpha(b, t.long())
+                at_next = compute_alpha(b, next_t.long())
+                xt = xs[-1].to(torch_device)
+                # Predict the noise and the corresponding x_0.
+                et = self.unet(xt, t)
+                x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+                x0_preds.append(x0_t.to('cpu'))
+                # DDIM update (Eq. (12) of the paper); eta controls the stochasticity.
+                c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+                c2 = ((1 - at_next) - c1 ** 2).sqrt()
+                xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+                xs.append(xt_next.to('cpu'))
+
+        return xs, x0_preds
diff --git a/models/vision/ddim/run_ddpm.py b/models/vision/ddim/run_ddpm.py
new file mode 100755
index 0000000000..88de931381
--- /dev/null
+++ b/models/vision/ddim/run_ddpm.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+import torch
+
+from diffusers import GaussianDDPMScheduler, UNetModel
+
+
+model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
+
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # timesteps: number of diffusion steps, loss_type: L1 or L2
+
+training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to the range -1 to +1
+loss = diffusion(training_images)
+loss.backward()
+# after a lot of training
+
+sampled_images = diffusion.sample(batch_size=4)
+sampled_images.shape  # (4, 3, 128, 128)
diff --git a/run_inference.py b/models/vision/ddim/run_inference.py
similarity index 86%
rename from run_inference.py
rename to models/vision/ddim/run_inference.py
index 38cdd3bb7d..59ed5865b2 100755
--- a/run_inference.py
+++ b/models/vision/ddim/run_inference.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # !pip install diffusers
-from diffusers import DiffusionPipeline
+from modeling_ddim import DDIM
 import PIL.Image
 import numpy as np
@@ -8,7 +8,7 @@
 model_id = "fusing/ddpm-cifar10"
 model_id = "fusing/ddpm-lsun-bedroom"
 
 # load model and scheduler
-ddpm = DiffusionPipeline.from_pretrained(model_id)
+ddpm = DDIM.from_pretrained(model_id)
 
 # run pipeline in inference (sample random noise and denoise)
 image = ddpm()
diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py
index e85d3cfe50..f84ab452a5 100644
--- a/models/vision/ddpm/modeling_ddpm.py
+++ b/models/vision/ddpm/modeling_ddpm.py
@@ -21,8 +21,6 @@ import torch
 
 
 class DDPM(DiffusionPipeline):
-    modeling_file = "modeling_ddpm.py"
-
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
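The pipeline_utils.py hunk below changes how `from_pretrained` resolves the pipeline class. A short sketch of the two call patterns it is meant to support; the model id is reused from run_inference.py above, and this describes the behaviour the hunk introduces rather than an already-released API:

from diffusers import DiffusionPipeline
from modeling_ddim import DDIM

# called on a concrete subclass: cls != DiffusionPipeline, so that class is used directly
ddim = DDIM.from_pretrained("fusing/ddpm-cifar10")

# called on the base class: "_class_name" and "_module" are read from the saved config and the
# pipeline class is loaded dynamically via get_class_from_dynamic_module
pipeline = DiffusionPipeline.from_pretrained("fusing/ddpm-cifar10")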
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 60ece225ab..2d4803d356 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -55,14 +55,13 @@ class DiffusionPipeline(ConfigMixin):
             class_name = module.__class__.__name__
 
             register_dict = {name: (library, class_name)}
-            # save model index config
             self.register(**register_dict)
 
             # set models
             setattr(self, name, module)
-
+
         register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
         self.register(**register_dict)
@@ -101,15 +100,15 @@ class DiffusionPipeline(ConfigMixin):
             cached_folder = pretrained_model_name_or_path
 
         config_dict = cls.get_config_dict(cached_folder)
-
-        module = config_dict["_module"]
-        class_name_ = config_dict["_class_name"]
-
-        if class_name_ == cls.__name__:
+
+        # if we load from explicit class, let's use it
+        if cls != DiffusionPipeline:
             pipeline_class = cls
         else:
+            # else we need to load the correct module from the Hub
+            class_name_ = config_dict["_class_name"]
+            module = config_dict["_module"]
             pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
-
         init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -120,6 +119,7 @@ class DiffusionPipeline(ConfigMixin):
 
         if library_name == module:
             # TODO(Suraj)
+            # for vq
             pass
 
         library = importlib.import_module(library_name)
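As a quick end-to-end check of the pieces added above, here is a minimal inference sketch combining run_inference.py with the `__call__` signature of the new DDIM pipeline. It assumes the `fusing/ddpm-cifar10` checkpoint loads with `DDIM.from_pretrained` (as run_inference.py does), and the tensor-to-image conversion at the end is illustrative rather than taken from the diff:

#!/usr/bin/env python3
import PIL.Image
import torch

from modeling_ddim import DDIM

# load model and scheduler (assumes the checkpoint is compatible with the DDIM pipeline)
ddim = DDIM.from_pretrained("fusing/ddpm-cifar10")

# deterministic DDIM sampling (eta=0.0) using 50 of the 1000 training timesteps
xs, x0_preds = ddim(batch_size=1, eta=0.0, inference_time_steps=50)

# the pipeline returns the whole denoising trajectory; the last entry is the final sample
sample = xs[-1]

# illustrative post-processing: map from [-1, 1] to [0, 255] and save as PNG
sample = ((sample + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1).cpu().numpy()[0]
PIL.Image.fromarray(sample).save("ddim_sample.png")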