From 1e21f061601dda0aa9740e88bfce68bf4aac4acd Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 8 Jun 2022 11:47:47 +0200
Subject: [PATCH] Classifier-free guidance scheduler + GLIDE pipeline

---
 models/vision/glide/README.md                 |   4 +
 models/vision/glide/convert_weights.py        |  46 ++++--
 models/vision/glide/modeling_glide.py         | 144 ++++++++++++++----
 models/vision/glide/run_glide.py              |  11 +-
 src/diffusers/__init__.py                     |   2 +
 src/diffusers/models/__init__.py              |   1 +
 .../diffusers/models/clip_text_transformer.py |   0
 src/diffusers/models/unet_glide.py            |  10 +-
 src/diffusers/pipeline_utils.py               |   4 +-
 src/diffusers/schedulers/__init__.py          |   1 +
 .../schedulers/classifier_free_guidance.py    | 103 +++++++++++++
 11 files changed, 277 insertions(+), 49 deletions(-)
 rename models/vision/glide/modelling_text_encoder.py => src/diffusers/models/clip_text_transformer.py (100%)
 create mode 100644 src/diffusers/schedulers/classifier_free_guidance.py

diff --git a/models/vision/glide/README.md b/models/vision/glide/README.md
index e69de29bb2..743c9bb6da 100644
--- a/models/vision/glide/README.md
+++ b/models/vision/glide/README.md
@@ -0,0 +1,4 @@
+# References
+
+[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
+[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
\ No newline at end of file
diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py
index 7ec1b92432..4f3320d7b2 100644
--- a/models/vision/glide/convert_weights.py
+++ b/models/vision/glide/convert_weights.py
@@ -1,25 +1,28 @@
-import argparse
-
 import torch
 from torch import nn
 from transformers import CLIPTextConfig, GPT2Tokenizer
 
-from modelling_text_encoder import CLIPTextModel
+from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
+from modeling_glide import GLIDE
 
 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
 state_dict = torch.load("base.pt", map_location="cpu")
 state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
+
+### Convert the text encoder
+
 config = CLIPTextConfig(
+    vocab_size=50257,
+    max_position_embeddings=128,
     hidden_size=512,
     intermediate_size=2048,
     num_hidden_layers=16,
     num_attention_heads=8,
-    max_position_embeddings=128,
+    use_padding_embeddings=True,
 )
 model = CLIPTextModel(config).eval()
 
 tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
-tokenizer.save_pretrained("./glide-base")
+#tokenizer.save_pretrained("./glide-base")
 
 hf_encoder = model.text_model
 
@@ -48,8 +51,33 @@ for layer_idx in range(config.num_hidden_layers):
     hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
     hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
 
-inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
-with torch.no_grad():
-    outputs = model(**inputs)
+#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
+#with torch.no_grad():
+#    outputs = model(**inputs)
 
-model.save_pretrained("./glide-base")
\ No newline at end of file
+#model.save_pretrained("./glide-base")
+
+### Convert the UNet
+
+unet_model = UNetGLIDEModel(
+    in_channels=3,
+    model_channels=192,
+    out_channels=6,
+    num_res_blocks=3,
+    attention_resolutions=(2, 4, 8),
+    dropout=0.1,
+    channel_mult=(1, 2, 3, 4),
+    num_heads=1,
+    num_head_channels=64,
+    num_heads_upsample=1,
+    use_scale_shift_norm=True,
+    resblock_updown=True,
+)
+
+unet_model.load_state_dict(state_dict, strict=False)
+
+scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
+
+glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
+
+glide.save_pretrained("./glide-base")
\ No newline at end of file
diff --git a/models/vision/glide/modeling_glide.py b/models/vision/glide/modeling_glide.py
index 747c173293..56c7b35f1d 100644
--- a/models/vision/glide/modeling_glide.py
+++ b/models/vision/glide/modeling_glide.py
@@ -14,46 +14,136 @@
 # limitations under the License.
 
-from diffusers import DiffusionPipeline
-from diffusers import UNetGLIDEModel
+from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
+from transformers import GPT2Tokenizer
 
 import tqdm
 import torch
+import numpy as np
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res + torch.zeros(broadcast_shape, device=timesteps.device)
 
 
 class GLIDE(DiffusionPipeline):
-    def __init__(self, unet: UNetGLIDEModel, noise_scheduler):
+    def __init__(
+        self,
+        unet: UNetGLIDEModel,
+        noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_encoder: CLIPTextModel,
+        tokenizer: GPT2Tokenizer
+    ):
         super().__init__()
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
 
-    def __call__(self, generator=None, torch_device=None):
+    def q_posterior_mean_variance(self, x_start, x_t, t):
+        """
+        Compute the mean and variance of the diffusion posterior:
+
+            q(x_{t-1} | x_t, x_0)
+
+        """
+        assert x_start.shape == x_t.shape
+        posterior_mean = (
+            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = _extract_into_tensor(
+            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
+        )
+        assert (
+            posterior_mean.shape[0]
+            == posterior_variance.shape[0]
+            == posterior_log_variance_clipped.shape[0]
+            == x_start.shape[0]
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
+        """
+        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+        the initial x, x_0.
+
+        :param model: the model, which takes a signal and a batch of timesteps
+                      as input.
+        :param x: the [N x C x ...] tensor at time t.
+        :param t: a 1-D Tensor of timesteps.
+        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :return: a dict with the following keys:
+                 - 'mean': the model mean output.
+                 - 'variance': the model variance output.
+                 - 'log_variance': the log of 'variance'.
+                 - 'pred_xstart': the prediction for x_0.
+        """
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        B, C = x.shape[:2]
+        assert t.shape == (B,)
+        model_output = model(x, t, transformer_out)
+
+        assert model_output.shape == (B, C * 2, *x.shape[2:])
+        model_output, model_var_values = torch.split(model_output, C, dim=1)
+        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
+        # model_var_values is in [-1, 1] for [min_var, max_var].
+        frac = (model_var_values + 1) / 2
+        model_log_variance = frac * max_log + (1 - frac) * min_log
+        model_variance = torch.exp(model_log_variance)
+
+        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+        if clip_denoised:
+            pred_xstart = pred_xstart.clamp(-1, 1)
+        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        return model_mean, model_variance, model_log_variance, pred_xstart
+
+    def _predict_xstart_from_eps(self, x_t, t, eps):
+        assert x_t.shape == eps.shape
+        return (
+            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+        )
+
+    def __call__(self, prompt, generator=None, torch_device=None):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
         self.unet.to(torch_device)
+        self.text_encoder.to(torch_device)
+
         # 1. Sample gaussian noise
-        image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
-        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
-            # i) define coefficients for time step t
-            clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
-            clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
-            clip_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
+        image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
 
-            # ii) predict noise residual
-            with torch.no_grad():
-                noise_residual = self.unet(image, t)
+        # 2. Encode tokens
+        # an empty input is needed to get the unconditional prediction that the model is guided away from (classifier-free guidance)
+        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
+        transformer_out = self.text_encoder(**inputs.to(torch_device)).last_hidden_state
 
-            # iii) compute predicted image from residual
-            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
-            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            prev_image = clip_coeff * pred_mean + image_coeff * image
-
-            # iv) sample variance
-            prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
-
-            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
-            sampled_prev_image = prev_image + prev_variance
-            image = sampled_prev_image
+        num_timesteps = len(self.noise_scheduler)
+        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
+            t = torch.tensor([i] * image.shape[0], device=torch_device)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, image, t, transformer_out)
+            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            nonzero_mask = (
+                (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
+            )  # no noise when t == 0
+            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
 
         return image
diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py
index 23cd4e103e..4d6d8e2da7 100644
--- a/models/vision/glide/run_glide.py
+++ b/models/vision/glide/run_glide.py
@@ -1,16 +1,11 @@
 import torch
 
-from .modeling_glide import GLIDE
-from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
+from modeling_glide import GLIDE
 
 generator = torch.Generator()
 generator = generator.manual_seed(0)
 
 # 1. Load models
-
-scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
-model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
-
-pipeline = GLIDE(model, scheduler)
+pipeline = GLIDE.from_pretrained("fusing/glide-base")
 
-img = pipeline(generator)
+img = pipeline("an oil painting of a corgi", generator=generator)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 3ce4142f65..1419140297 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -7,5 +7,7 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import UNetGLIDEModel
+from .models.clip_text_transformer import CLIPTextModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 85f1cc03f6..964c0200d6 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -18,3 +18,4 @@
 
 from .unet import UNetModel
 from .unet_glide import UNetGLIDEModel
+from .clip_text_transformer import CLIPTextModel
diff --git a/models/vision/glide/modelling_text_encoder.py b/src/diffusers/models/clip_text_transformer.py
similarity index 100%
rename from models/vision/glide/modelling_text_encoder.py
rename to src/diffusers/models/clip_text_transformer.py
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index 363d01137f..4b5cc971fc 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = torch.float16 if use_fp16 else torch.float32
+        # self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -653,13 +653,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         transformer_proj = self.transformer_proj(transformer_out[:, -1])
         transformer_out = transformer_out.permute(0, 2, 1)  # NLC -> NCL
 
+        emb = emb + transformer_proj.to(emb)
+
         h = x.type(self.dtype)
         for module in self.input_blocks:
-            h = module(h, emb)
+            h = module(h, emb, transformer_out)
             hs.append(h)
-        h = self.middle_block(h, emb)
+        h = self.middle_block(h, emb, transformer_out)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
-            h = module(h, emb)
+            h = module(h, emb, transformer_out)
         h = h.type(x.dtype)
         return self.out(h)
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 60ece225ab..dfc7f6d681 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -35,10 +35,12 @@ logger = logging.get_logger(__name__)
 LOADABLE_CLASSES = {
     "diffusers": {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
+        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
         "GaussianDDPMScheduler": ["save_config", "from_config"],
+        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
     },
     "transformers": {
-        "ModelMixin": ["save_pretrained", "from_pretrained"],
+        "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
     },
 }
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 81d9601a1b..82084c6c41 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -17,3 +17,4 @@
 # limitations under the License.
 
 from .gaussian_ddpm import GaussianDDPMScheduler
+from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py
new file mode 100644
index 0000000000..17222c17b9
--- /dev/null
+++ b/src/diffusers/schedulers/classifier_free_guidance.py
@@ -0,0 +1,103 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import math
+from torch import nn
+import numpy as np
+
+from ..configuration_utils import ConfigMixin
+
+
+SAMPLING_CONFIG_NAME = "scheduler_config.json"
+
+
+def linear_beta_schedule(timesteps, beta_start, beta_end):
+    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas, dtype=np.float64)
+
+
+class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
+
+    config_name = SAMPLING_CONFIG_NAME
+
+    def __init__(
+        self,
+        timesteps=1000,
+        beta_schedule="squaredcos_cap_v2",
+    ):
+        super().__init__()
+        self.register(
+            timesteps=timesteps,
+            beta_schedule=beta_schedule,
+        )
+        self.num_timesteps = int(timesteps)
+
+        if beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.betas = betas  # kept as an attribute: p_mean_variance reads noise_scheduler.betas
+        alphas = 1.0 - betas
+        self.alphas_cumprod = np.cumprod(alphas, axis=0)
+        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        self.posterior_variance = (
+            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.posterior_log_variance_clipped = np.log(
+            np.append(self.posterior_variance[1], self.posterior_variance[1:])
+        )
+        self.posterior_mean_coef1 = (
+            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        self.posterior_mean_coef2 = (
+            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+        )
+
+    def sample_noise(self, shape, device, generator=None):
+        # always sample on CPU to be deterministic
+        return torch.randn(shape, generator=generator).to(device)
+
+    def __len__(self):
+        return self.num_timesteps
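
Note on the guidance step itself: as committed, the pipeline encodes both the prompt and an empty prompt, but `p_mean_variance` passes a batch of one image together with two text sequences and never mixes the conditional and unconditional predictions, so the classifier-free guidance that gives the scheduler its name is not applied yet. Below is a minimal sketch of the standard mixing step, assuming a hypothetical `guidance_scale` parameter and the eps/variance channel split used by the GLIDE UNet above; it is an illustration, not part of the patch:

```python
import torch

def cfg_model_fn(unet, x, t, transformer_out, guidance_scale=3.0):
    # transformer_out holds two rows, [prompt, empty prompt], so run the
    # UNet on a duplicated image batch and split the result back apart.
    combined = torch.cat([x, x], dim=0)
    t_in = torch.cat([t, t], dim=0)
    model_out = unet(combined, t_in, transformer_out)

    # The GLIDE UNet predicts 2 * C channels: eps plus the variance values.
    C = x.shape[1]
    eps, rest = model_out[:, :C], model_out[:, C:]
    cond_eps, uncond_eps = torch.split(eps, x.shape[0], dim=0)

    # Classifier-free guidance: push eps away from the unconditional prediction.
    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    # Keep the conditional variance values for the guided half of the batch.
    return torch.cat([guided_eps, rest[: x.shape[0]]], dim=1)
```

With something like this in place, `p_mean_variance` could call `cfg_model_fn(self.unet, x, t, transformer_out)` where it currently calls `model(x, t, transformer_out)`, which would also resolve the batch-size mismatch between the single image and the two encoded sequences.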
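The schedule math in `ClassifierFreeGuidanceScheduler.__init__` can also be sanity-checked in isolation; a standalone sketch (plain numpy, illustrative variable names) that reproduces the `squaredcos_cap_v2` betas computed by `betas_for_alpha_bar` above:

```python
import math
import numpy as np

def alpha_bar(t):
    # Cumulative product of (1 - beta) at continuous time t in [0, 1]
    # (the GLIDE cosine schedule passed to betas_for_alpha_bar).
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T = 1000
betas = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
alphas_cumprod = np.cumprod(1.0 - betas)

# Early betas are tiny (gentle noising) and the last ones hit the 0.999 cap,
# so almost no signal survives at the final timestep.
print(betas[0], betas[-1])    # roughly 4e-5 and 0.999
print(alphas_cumprod[-1])     # close to 0
```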