diff --git a/README.md b/README.md
index 10d3655b21..7c0a2fe71f 100644
--- a/README.md
+++ b/README.md
@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-**Text to Image generation with Latent Diffusion**
+#### **Text to Image generation with Latent Diffusion**
 
 ```python
 from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 
 # save image
 image_pil.save("test.png")
+```
+
+#### **Text to speech with BDDM**
+
+_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the Tacotron 2 model._
+
+```python
+import torch
+from diffusers import BDDM, DiffusionPipeline
+
+torch_device = "cuda"
+
+# load the BDDM pipeline
+bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
+
+# load tacotron2 to get the mel spectrograms
+tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
+tacotron2 = tacotron2.to(torch_device).eval()
+
+text = "Hello world, I missed you so much."
+
+utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
+sequences, lengths = utils.prepare_input_sequence([text])
+
+# generate mel spectrograms from the text
+with torch.no_grad():
+    mel_spec, _, _ = tacotron2.infer(sequences, lengths)
+
+# generate the speech by passing the mel spectrograms to the BDDM pipeline
+generator = torch.manual_seed(0)
+audio = bddm(mel_spec, generator, torch_device)
+
+# save generated audio
+from scipy.io.wavfile import write as wavwrite
+sampling_rate = 22050
+wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
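+
+#### **Image generation with PNDM**
+
+_A minimal usage sketch of the new PNDM pipeline, mirroring the [`run_pndm.py`](run_pndm.py) script added in this PR; `fusing/ddim-celeba-hq` is the checkpoint that script samples from._
+
+```python
+from diffusers import PNDM, UNetModel, PNDMScheduler
+import PIL.Image
+import numpy as np
+import torch
+
+model_id = "fusing/ddim-celeba-hq"
+
+model = UNetModel.from_pretrained(model_id)
+scheduler = PNDMScheduler()
+
+# build the pipeline from the pretrained model and the PNDM scheduler
+pndm = PNDM(unet=model, noise_scheduler=scheduler)
+
+# run pipeline in inference (sample random noise and denoise)
+image = pndm()
+
+# process image to PIL
+image_processed = image.cpu().permute(0, 2, 3, 1)
+image_processed = (image_processed + 1.0) / 2
+image_processed = torch.clamp(image_processed, 0.0, 1.0)
+image_processed = (image_processed * 255).numpy().astype(np.uint8)
+image_pil = PIL.Image.fromarray(image_processed[0])
+
+# save image
+image_pil.save("test.png")
+```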
 
 ## Library structure:
 
 ```
-├── models
-│   ├── audio
-│   │   └── fastdiff
-│   │       ├── modeling_fastdiff.py
-│   │       ├── README.md
-│   │       └── run_fastdiff.py
-│   ├── __init__.py
-│   └── vision
-│       ├── dalle2
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── ddpm
-│       │   ├── example.py
-│       │   ├── modeling_ddpm.py
-│       │   ├── README.md
-│       │   └── run_ddpm.py
-│       ├── glide
-│       │   ├── modeling_glide.py
-│       │   ├── modeling_vqvae.py.py
-│       │   ├── README.md
-│       │   └── run_glide.py
-│       ├── imagen
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── __init__.py
-│       └── latent_diffusion
-│           ├── modeling_latent_diffusion.py
-│           ├── README.md
-│           └── run_latent_diffusion.py
-├── pyproject.toml
+├── LICENSE
+├── Makefile
 ├── README.md
+├── pyproject.toml
+├── run_pndm.py
 ├── setup.cfg
 ├── setup.py
 ├── src
-│   └── diffusers
-│       ├── configuration_utils.py
-│       ├── __init__.py
-│       ├── modeling_utils.py
-│       ├── models
-│       │   ├── __init__.py
-│       │   ├── unet_glide.py
-│       │   └── unet.py
-│       ├── pipeline_utils.py
-│       └── schedulers
-│           ├── gaussian_ddpm.py
-│           ├── __init__.py
+│   └── diffusers
+│       ├── __init__.py
+│       ├── configuration_utils.py
+│       ├── dependency_versions_check.py
+│       ├── dependency_versions_table.py
+│       ├── dynamic_modules_utils.py
+│       ├── modeling_utils.py
+│       ├── models
+│       │   ├── __init__.py
+│       │   ├── unet.py
+│       │   ├── unet_glide.py
+│       │   └── unet_ldm.py
+│       ├── pipeline_utils.py
+│       ├── pipelines
+│       │   ├── __init__.py
+│       │   ├── configuration_ldmbert.py
+│       │   ├── conversion_glide.py
+│       │   ├── modeling_vae.py
+│       │   ├── pipeline_bddm.py
+│       │   ├── pipeline_ddim.py
+│       │   ├── pipeline_ddpm.py
+│       │   ├── pipeline_glide.py
+│       │   ├── pipeline_latent_diffusion.py
+│       │   └── pipeline_pndm.py
+│       ├── schedulers
+│       │   ├── __init__.py
+│       │   ├── classifier_free_guidance.py
+│       │   ├── scheduling_ddim.py
+│       │   ├── scheduling_ddpm.py
+│       │   ├── scheduling_pndm.py
+│       │   └── scheduling_utils.py
+│       ├── testing_utils.py
+│       └── utils
+│           ├── __init__.py
+│           └── logging.py
 ├── tests
-│   └── test_modeling_utils.py
+│   ├── __init__.py
+│   ├── test_modeling_utils.py
+│   └── test_scheduler.py
+└── utils
+    ├── check_config_docstrings.py
+    ├── check_copies.py
+    ├── check_dummies.py
+    ├── check_inits.py
+    ├── check_repo.py
+    ├── check_table.py
+    └── check_tf_ops.py
 ```
diff --git a/run_pndm.py b/run_pndm.py
new file mode 100755
index 0000000000..6ef17bff33
--- /dev/null
+++ b/run_pndm.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+from diffusers import PNDM, UNetModel, PNDMScheduler
+import PIL.Image
+import numpy as np
+import torch
+
+model_id = "fusing/ddim-celeba-hq"
+
+model = UNetModel.from_pretrained(model_id)
+scheduler = PNDMScheduler()
+
+# build the pipeline from the pretrained model and the PNDM scheduler
+pndm = PNDM(unet=model, noise_scheduler=scheduler)
+
+# run pipeline in inference (sample random noise and denoise)
+image = pndm()
+
+# process image to PIL
+image_processed = image.cpu().permute(0, 2, 3, 1)
+image_processed = (image_processed + 1.0) / 2
+image_processed = torch.clamp(image_processed, 0.0, 1.0)
+image_processed = image_processed * 255
+image_processed = image_processed.numpy().astype(np.uint8)
+image_pil = PIL.Image.fromarray(image_processed[0])
+
+# save image
+image_pil.save("test.png")
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index c7127aa8e9..e374e3aed2 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
-from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
+from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ad42aead20..e0d2bf2e30 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,5 +1,6 @@
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
+from .pipeline_pndm import PNDM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_bddm import BDDM
diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py
new file mode 100644
index 0000000000..b9a04d98f9
--- /dev/null
+++ b/src/diffusers/pipelines/pipeline_pndm.py
@@ -0,0 +1,110 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
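+
+# PNDM ("Pseudo Numerical Methods for Diffusion Models on Manifolds",
+# https://arxiv.org/abs/2202.09778) replaces the stochastic DDPM update with a
+# numerical-ODE-style one: each denoising step combines several noise
+# predictions. The first three steps use a Runge-Kutta warm-up, since no
+# history is available yet; every later step reuses the last four cached
+# predictions in `ets` via a 4th-order Adams-Bashforth linear multistep.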
+
+
+import torch
+
+from ..pipeline_utils import DiffusionPipeline
+
+
+class PNDM(DiffusionPipeline):
+    def __init__(self, unet, noise_scheduler):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+
+    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        num_trained_timesteps = self.noise_scheduler.timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
+
+        self.unet.to(torch_device)
+
+        # sample gaussian noise to begin the loop
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            generator=generator,
+        )
+        image = image.to(torch_device)
+
+        seq = inference_step_times
+        seq_next = [-1] + list(seq[:-1])
+        model = self.unet
+
+        # `ets` caches the noise predictions of past steps for the multistep formula
+        ets = []
+        for i, j in zip(reversed(seq), reversed(seq_next)):
+            t = torch.ones(image.shape[0]) * i
+            t_next = torch.ones(image.shape[0]) * j
+
+            residual = model(image.to(torch_device), t.to(torch_device))
+            residual = residual.to("cpu")
+
+            t_list = [t, (t + t_next) / 2, t_next]
+
+            if len(ets) <= 2:
+                # too little history for the multistep formula:
+                # run a 4th-order Runge-Kutta warm-up step instead
+                ets.append(residual)
+                image = image.to("cpu")
+                x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
+                e_2 = model(x_2.to(torch_device), t_list[1].to(torch_device)).to("cpu")
+                x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
+                e_3 = model(x_3.to(torch_device), t_list[1].to(torch_device)).to("cpu")
+                x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
+                e_4 = model(x_4.to(torch_device), t_list[2].to(torch_device)).to("cpu")
+                residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
+            else:
+                # 4th-order Adams-Bashforth linear multistep over the cached predictions
+                ets.append(residual)
+                residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
+
+            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
+
+            # the same update, delegated entirely to the scheduler:
+            # with torch.no_grad():
+            #     t_start, t_end = t_next, t
+            #     img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
+
+            image = img_next
+
+        return image
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 3b1cf92c77..5e9dcaf64e 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -19,4 +19,5 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
diff --git a/src/diffusers/schedulers/scheduling_plms.py b/src/diffusers/schedulers/scheduling_plms.py
deleted file mode 100644
index fd9809e5b5..0000000000
--- a/src/diffusers/schedulers/scheduling_plms.py
+++ /dev/null
@@ -1,341 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import numpy as np
-import torch
-
-from tqdm import tqdm
-
-from ..configuration_utils import ConfigMixin
-from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
-
-
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
-
-
-def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
-    if ddim_discr_method == "uniform":
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == "quad":
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
-    else:
-        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
-
-    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
-    # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
-    if verbose:
-        print(f"Selected timesteps for ddim sampler: {steps_out}")
-    return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
-    # select alphas for computing the variance schedule
-    alphas = alphacums[ddim_timesteps]
-    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
-    # according the the formula provided in https://arxiv.org/abs/2010.02502
-    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
-    if verbose:
-        print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
-        print(
-            f"For the chosen value of eta, which is {eta}, "
-            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
-        )
-    return sigmas, alphas, alphas_prev
-
-
-class PLMSSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
-        super().__init__()
-        self.model = model
-        self.ddpm_num_timesteps = model.num_timesteps
-        self.schedule = schedule
-
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
-        setattr(self, name, attr)
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True):
-        if ddim_eta != 0:
-            raise ValueError("ddim_eta must be 0 for PLMS")
-        self.ddim_timesteps = make_ddim_timesteps(
-            ddim_discr_method=ddim_discretize,
-            num_ddim_timesteps=ddim_num_steps,
-            num_ddpm_timesteps=self.ddpm_num_timesteps,
-            verbose=verbose,
-        )
-        alphas_cumprod = self.model.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
-        self.register_buffer("betas", to_torch(self.model.betas))
-        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
-        self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev))
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())))
-        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())))
-        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())))
-        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)))
-
-        # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
-            alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose
-        )
-        self.register_buffer("ddim_sigmas", ddim_sigmas)
-        self.register_buffer("ddim_alphas", ddim_alphas)
-        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
-        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev)
-            / (1 - self.alphas_cumprod)
-            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
-        )
-        self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps)
-
-    @torch.no_grad()
-    def sample(
-        self,
-        S,
-        batch_size,
-        shape,
-        conditioning=None,
-        callback=None,
-        normals_sequence=None,
-        img_callback=None,
-        quantize_x0=False,
-        eta=0.0,
-        mask=None,
-        x0=None,
-        temperature=1.0,
-        noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
-        verbose=True,
-        x_T=None,
-        log_every_t=100,
-        unconditional_guidance_scale=1.0,
-        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
- **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f"Data shape for PLMS sampling is {size}") - - samples, intermediates = self.plms_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 
- img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample_plms( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - ) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, - ): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - ) - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep 
(Adams-Bashforth)
-            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
-        elif len(old_eps) >= 3:
-            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
-        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
-        return x_prev, pred_x0, e_t
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 0000000000..d69e4a5188
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,138 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+
+from ..configuration_utils import ConfigMixin
+from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+    def __init__(
+        self,
+        timesteps=1000,
+        beta_start=0.0001,
+        beta_end=0.02,
+        beta_schedule="linear",
+        tensor_format="np",
+    ):
+        super().__init__()
+        self.register(
+            timesteps=timesteps,
+            beta_start=beta_start,
+            beta_end=beta_end,
+            beta_schedule=beta_schedule,
+        )
+        self.timesteps = int(timesteps)
+
+        if beta_schedule == "linear":
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            self.betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        self.one = np.array(1.0)
+
+        self.set_format(tensor_format=tensor_format)
+
+        # self.register_buffer("betas", betas.to(torch.float32))
+        # self.register_buffer("alphas", alphas.to(torch.float32))
+        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+
+        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+        # TODO(PVP) - check how many of these are actually necessary!
+        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
+        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
+        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+        # if variance_type == "fixed_small":
+        #     log_variance = torch.log(variance.clamp(min=1e-20))
+        # elif variance_type == "fixed_large":
+        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
+        #
+        # self.register_buffer("log_variance", log_variance.to(torch.float32))
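+
+    # How sampling proceeds (all of it deterministic):
+    # - `step` advances a sample between two timesteps. While fewer than four
+    #   noise predictions are cached in `ets`, it falls back to `runge_kutta`;
+    #   afterwards it combines the four most recent predictions with the
+    #   4th-order Adams-Bashforth coefficients (55, -59, 37, -9) / 24.
+    # - `runge_kutta` is the classical RK4 warm-up, evaluating the model at
+    #   `t`, twice at the midpoint, and at `t_next`.
+    # - `transfer` moves `x` from timestep `t` to `t_next` given a noise
+    #   estimate `et`.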
+
+    def get_alpha(self, time_step):
+        return self.alphas[time_step]
+
+    def get_beta(self, time_step):
+        return self.betas[time_step]
+
+    def get_alpha_prod(self, time_step):
+        if time_step < 0:
+            return self.one
+        return self.alphas_cumprod[time_step]
+
+    def step(self, img, t_start, t_end, model, ets):
+        # adapted from `gen_order_4` in the original PNDM implementation;
+        # sampling runs backwards in time, so `t_start` is the later timestep
+        t_next, t = t_start, t_end
+
+        # run the noise prediction on the model's device, but keep the
+        # multistep arithmetic on CPU, where the schedule arrays live
+        device = next(model.parameters()).device
+
+        noise_ = model(img.to(device), t.to(device))
+        noise_ = noise_.to("cpu")
+
+        t_list = [t, (t + t_next) / 2, t_next]
+        if len(ets) > 2:
+            # 4th-order Adams-Bashforth linear multistep over the cached predictions
+            ets.append(noise_)
+            noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
+        else:
+            # not enough history yet: Runge-Kutta warm-up
+            noise = self.runge_kutta(img, t_list, model, ets, noise_)
+
+        img_next = self.transfer(img.to("cpu"), t, t_next, noise)
+        return img_next, ets
+
+    def runge_kutta(self, x, t_list, model, ets, noise_):
+        device = next(model.parameters()).device
+        x = x.to("cpu")
+
+        e_1 = noise_
+        ets.append(e_1)
+        x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
+
+        e_2 = model(x_2.to(device), t_list[1].to(device))
+        e_2 = e_2.to("cpu")
+        x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
+
+        e_3 = model(x_3.to(device), t_list[1].to(device))
+        e_3 = e_3.to("cpu")
+        x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
+
+        e_4 = model(x_4.to(device), t_list[2].to(device))
+        e_4 = e_4.to("cpu")
+
+        # classical RK4 combination of the four noise estimates
+        et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
+
+        return et
+
+    def transfer(self, x, t, t_next, et):
+        # move `x` from timestep `t` to `t_next` given the noise estimate `et`
+        alphas_cump = self.alphas_cumprod
+        at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
+        at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
+
+        x_delta = (at_next - at) * (
+            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
+            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
+        )
+
+        x_next = x + x_delta
+        return x_next
+
+    def __len__(self):
+        return self.timesteps
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 9382736d01..6c119479fa 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -19,7 +19,7 @@ import unittest
 
 import torch
 
-from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel
+from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDM, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
         )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
+    @slow
+    def test_pndm_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        unet = UNetModel.from_pretrained(model_id)
+        noise_scheduler = PNDMScheduler(tensor_format="pt")
+
+        pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler)
+        image = pndm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor(
+            [-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
+        )
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
     @slow
     def test_ldm_text2img(self):
         model_id = "fusing/latent-diffusion-text2im-large"