From 9e31c6a7498e025efa9a8383b2eb71fae9502993 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Tue, 21 Jun 2022 14:07:58 +0200
Subject: [PATCH] refactor GLIDE text2im pipeline, remove
 classifier_free_guidance

---
 examples/train_unconditional.py             |   2 +-
 src/diffusers/__init__.py                   |   1 -
 src/diffusers/optimization.py               |   2 +-
 src/diffusers/pipeline_utils.py             |   1 -
 src/diffusers/pipelines/pipeline_glide.py   | 145 +++++-------
 src/diffusers/schedulers/__init__.py        |   1 -
 .../schedulers/classifier_free_guidance.py  |  96 ------------
 src/diffusers/schedulers/scheduling_ddpm.py |  14 +-
 8 files changed, 54 insertions(+), 208 deletions(-)
 delete mode 100644 src/diffusers/schedulers/classifier_free_guidance.py

diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index 476ae95293..fbdc0aba29 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -10,6 +10,7 @@ from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.modeling_utils import unwrap_model
+from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
@@ -21,7 +22,6 @@ from torchvision.transforms import (
     ToTensor,
 )
 from tqdm.auto import tqdm
-from diffusers.optimization import get_scheduler


 logger = logging.get_logger(__name__)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 22bc3b022a..881d48240e 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -13,7 +13,6 @@ from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
-from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


 if is_transformers_available():
diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 70101aec81..84712bf809 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -273,4 +273,4 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
\ No newline at end of file
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index f19e07de28..d8a2644dc9 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -36,7 +36,6 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
-        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py
index 50dc73730f..07603e153e 100644
--- a/src/diffusers/pipelines/pipeline_glide.py
+++ b/src/diffusers/pipelines/pipeline_glide.py
@@ -32,7 +32,7 @@ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forwar

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
+from ..schedulers import DDIMScheduler, DDPMScheduler
 from ..utils import logging
@@ -715,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
@@ -731,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-
-            q(x_{t-1} | x_t, x_0)
-
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
-        the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps
-                      as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to
-            pass to the model. This can be used for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
+        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)

-        # Create a classifier-free guidance sampling function
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
@@ -833,7 +761,15 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

         # 2. Encode tokens
         # an empty input is needed to guide the model away from it
@@ -843,25 +779,30 @@ class GLIDE(DiffusionPipeline):
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

         # 3. Run the text2image generation step
-        num_timesteps = len(self.text_noise_scheduler)
-        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
-            t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
-                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
-            )
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+                pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
             noise = torch.randn(image.shape, generator=generator).to(torch_device)
-            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
-            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
+            variance = torch.exp(0.5 * model_log_variance) * noise
+
+            # set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance

         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997

         # Sample gaussian noise to begin loop
         image = torch.randn(
@@ -877,8 +818,6 @@ class GLIDE(DiffusionPipeline):

         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
-        # adapt the beta schedule to the number of steps
-        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)

         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 47e5f6a1db..b2d533d380 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -16,7 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py
deleted file mode 100644
index f5a0b2d942..0000000000
--- a/src/diffusers/schedulers/classifier_free_guidance.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import numpy as np
-import torch
-from torch import nn
-
-from ..configuration_utils import ConfigMixin
-
-
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-
-
-def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas, dtype=np.float64)
-
-
-class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
-
-    config_name = SAMPLING_CONFIG_NAME
-
-    def __init__(
-        self,
-        timesteps=1000,
-        beta_schedule="squaredcos_cap_v2",
-    ):
-        super().__init__()
-        self.register_to_config(
-            timesteps=timesteps,
-            beta_schedule=beta_schedule,
-        )
-
-        if beta_schedule == "squaredcos_cap_v2":
-            # GLIDE cosine schedule
-            self.betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(alphas, axis=0)
-        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
-        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-        self.posterior_log_variance_clipped = np.log(
-            np.append(self.posterior_variance[1], self.posterior_variance[1:])
-        )
-        self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
-
-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
-    def __len__(self):
-        return self.config.timesteps
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index bc95c0afa8..eb85796f27 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         self.set_format(tensor_format=tensor_format)

-    def get_variance(self, t):
+    def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

@@ -96,14 +96,20 @@

         # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

+        if variance_type is None:
+            variance_type = self.config.variance_type
+
         # hacks - were probs added for training stability
-        if self.config.variance_type == "fixed_small":
+        if variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
         # for rl-diffuser https://arxiv.org/abs/2205.09991
-        elif self.config.variance_type == "fixed_small_log":
+        elif variance_type == "fixed_small_log":
             variance = self.log(self.clip(variance, min_value=1e-20))
-        elif self.config.variance_type == "fixed_large":
+        elif variance_type == "fixed_large":
             variance = self.betas[t]
+        elif variance_type == "fixed_large_log":
+            # GLIDE max_log
+            variance = self.log(self.betas[t])

         return variance
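
Reviewer note: the `text_model_fn` closure retained in the pipeline implements standard classifier-free guidance: the batch carries a conditional and an unconditional copy of the same latents, and the two noise predictions are recombined before sampling. A minimal, self-contained sketch of just that recombination step; the `guided_eps` helper and the tensor shapes are illustrative, not part of this patch:

    import torch

    def guided_eps(cond_eps, uncond_eps, guidance_scale=3.0):
        # guidance_scale 0.0 keeps the unconditional prediction, 1.0 keeps the
        # conditional one; values > 1 extrapolate past the conditional
        # prediction for stronger prompt adherence.
        return uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    # Illustrative shapes: batch of 1, 3 noise channels, 64x64 latents.
    cond_eps = torch.randn(1, 3, 64, 64)
    uncond_eps = torch.randn(1, 3, 64, 64)
    eps = guided_eps(cond_eps, uncond_eps, guidance_scale=3.0)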
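Reviewer note: the text2image loop now builds the learned variance by interpolating, in log space, between the scheduler's "fixed_small_log" and "fixed_large_log" bounds, which is exactly what the new `variance_type` override on `DDPMScheduler.get_variance` exposes. A self-contained sketch of that interpolation; the toy linear beta schedule below stands in for the scheduler's configured schedule:

    import numpy as np

    # Toy linear beta schedule standing in for the scheduler's configured one.
    betas = np.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = np.cumprod(1.0 - betas)

    def log_variance_bounds(t):
        alpha_prod_t = alphas_cumprod[t]
        alpha_prod_t_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
        # Posterior variance of q(x_{t-1} | x_t, x_0), as in get_variance.
        posterior = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[t]
        min_log = np.log(np.clip(posterior, 1e-20, None))  # "fixed_small_log"
        max_log = np.log(betas[t])                         # "fixed_large_log"
        return min_log, max_log

    def interpolate_log_variance(model_var_values, t):
        # model_var_values is in [-1, 1]; map it linearly onto [min_log, max_log].
        min_log, max_log = log_variance_bounds(t)
        frac = (model_var_values + 1) / 2
        return frac * max_log + (1 - frac) * min_log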
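Reviewer note: with `guidance_scale`, `eta`, and `upsample_temp` promoted from hard-coded constants to `__call__` arguments, they can now be tuned per invocation. A hypothetical usage sketch; the checkpoint id is a placeholder, the top-level `GLIDE` export is assumed to be available when transformers is installed, and the return value is assumed to be the upscaled image tensor:

    import torch
    from diffusers import GLIDE  # assumed export when transformers is available

    # Placeholder checkpoint id; substitute a real GLIDE checkpoint.
    pipeline = GLIDE.from_pretrained("org/glide-checkpoint")

    generator = torch.manual_seed(0)
    image = pipeline(
        "an oil painting of a corgi",
        generator=generator,
        num_inference_steps_upscale=50,
        guidance_scale=3.0,   # strength of classifier-free guidance
        eta=0.0,              # DDIM eta for the upscaling stage
        upsample_temp=0.997,  # 1.0 is sharper but can produce grainy artifacts
    )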