
Classifier-free guidance scheduler + GLIDE pipeline

This commit is contained in:
anton-l
2022-06-08 11:47:47 +02:00
parent d1715d3385
commit 1e21f06160
11 changed files with 275 additions and 48 deletions

View File

@@ -0,0 +1,4 @@
# References
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
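
At sampling time, classifier-free guidance mixes a conditional and an unconditional noise prediction — this is why the pipeline encodes the prompt together with an empty `""` string. A minimal sketch of the mixing step; the guidance scale `s` and the function name are assumptions from the GLIDE paper, not part of this commit:

```python
import torch

def classifier_free_guidance(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, s: float) -> torch.Tensor:
    # eps_cond: UNet noise prediction conditioned on the text prompt
    # eps_uncond: UNet noise prediction for the empty ("") prompt
    # s > 1 pushes samples toward the prompt at the cost of diversity
    return eps_uncond + s * (eps_cond - eps_uncond)
```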

View File

@@ -1,25 +1,28 @@
import argparse
import torch
from torch import nn
from transformers import CLIPTextConfig, GPT2Tokenizer
from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from modeling_glide import GLIDE
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
#tokenizer.save_pretrained("./glide-base")
hf_encoder = model.text_model
@@ -48,8 +51,33 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
#with torch.no_grad():
# outputs = model(**inputs)
#model.save_pretrained("./glide-base")
### Convert the UNet
unet_model = UNetGLIDEModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
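# strict=False: base.pt also holds the text-transformer weights converted above, which
# have no counterpart in the UNet module (an assumption about the checkpoint layout)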
unet_model.load_state_dict(state_dict, strict=False)
scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
glide.save_pretrained("./glide-base")

View File

@@ -14,46 +14,136 @@
# limitations under the License.
from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from transformers import GPT2Tokenizer
import tqdm
import torch
import numpy as np
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a float tensor of shape broadcast_shape, with the extracted values broadcast along the trailing dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
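# e.g. arr: np.ndarray of shape [T]; timesteps: LongTensor of shape [B];
# broadcast_shape (B, C, H, W) -> returns a (B, C, H, W) tensor in which every
# element of batch entry b equals arr[timesteps[b]]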
class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        unet: UNetGLIDEModel,
        noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
    ):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param transformer_out: the last hidden state of the text encoder, used to condition the model.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, transformer_out)
assert model_output.shape == (B, C * 2, *x.shape[2:])
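        # the UNet outputs 2*C channels: the first C are the predicted noise eps, the
        # remaining C parameterize the per-pixel variance (Improved DDPM, Nichol & Dhariwal)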
model_output, model_var_values = torch.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = torch.exp(model_log_variance)
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
if clip_denoised:
pred_xstart = pred_xstart.clamp(-1, 1)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return model_mean, model_variance, model_log_variance, pred_xstart
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
    def __call__(self, prompt, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        # 1. Sample gaussian noise
        image = self.noise_scheduler.sample_noise(
            (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens: an empty prompt is passed alongside the real one as the
        # unconditional input for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
        transformer_out = self.text_encoder(**inputs).last_hidden_state

        # 3. Denoise step by step: sample x_{t-1} ~ p(x_{t-1} | x_t) from t = T - 1 down to 0
        num_timesteps = len(self.noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            with torch.no_grad():
                mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, image, t, transformer_out)
            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        return image
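# A minimal usage sketch (assuming weights saved by the conversion script):
#   pipeline = GLIDE.from_pretrained("fusing/glide-base")
#   image = pipeline("an oil painting of a corgi")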

View File

@@ -1,16 +1,11 @@
import torch
from modeling_glide import GLIDE
generator = torch.Generator()
generator = generator.manual_seed(0)
# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")
img = pipeline("an oil painting of a corgi", generator=generator)
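# the sampled tensor has shape (1, 3, 64, 64) with values in [-1, 1]; converting it
# to a PIL image for inspection is a sketch/assumption, not part of this commit
from PIL import Image

img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)
Image.fromarray(img[0].permute(1, 2, 0).cpu().numpy()).save("corgi.png")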

View File

@@ -7,5 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.clip_text_transformer import CLIPTextModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

View File

@@ -18,3 +18,4 @@
from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .clip_text_transformer import CLIPTextModel

View File

@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
#self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@@ -653,13 +653,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
transformer_proj = self.transformer_proj(transformer_out[:, -1])
transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL
emb = emb + transformer_proj.to(emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, transformer_out)
hs.append(h)
h = self.middle_block(h, emb, transformer_out)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, transformer_out)
h = h.type(x.dtype)
return self.out(h)

View File

@@ -35,10 +35,12 @@ logger = logging.get_logger(__name__)
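# maps, per library, a class name to the (save, load) method names that
# DiffusionPipeline.save_pretrained / from_pretrained invoke on each registered module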
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers
"GaussianDDPMScheduler": ["save_config", "from_config"],
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
},
"transformers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
},
}

View File

@@ -17,3 +17,4 @@
# limitations under the License.
from .gaussian_ddpm import GaussianDDPMScheduler
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler

View File

@@ -0,0 +1,102 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from torch import nn
import numpy as np
from ..configuration_utils import ConfigMixin
SAMPLING_CONFIG_NAME = "scheduler_config.json"
def linear_beta_schedule(timesteps, beta_start, beta_end):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float64)
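# since alpha_bar(t) = prod_{s<=t} (1 - beta_s), the ratio of consecutive values gives
# 1 - beta_i = alpha_bar(t_{i+1}) / alpha_bar(t_i), which is the formula used above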
class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_schedule="squaredcos_cap_v2",
):
super().__init__()
self.register(
timesteps=timesteps,
beta_schedule=beta_schedule,
)
self.num_timesteps = int(timesteps)
if beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # keep the betas on the scheduler: the GLIDE pipeline reads noise_scheduler.betas
        # as the max-variance bound in p_mean_variance
        self.betas = betas
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
)
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.num_timesteps
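# minimal usage sketch:
#   scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
#   assert len(scheduler) == 1000
#   noise = scheduler.sample_noise((1, 3, 64, 64), device="cpu")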