mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Classifier-free guidance scheduler + GLIDe pipeline
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
# References
|
||||
|
||||
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
|
||||
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
|
||||
@@ -1,25 +1,28 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import CLIPTextConfig, GPT2Tokenizer
|
||||
from modelling_text_encoder import CLIPTextModel
|
||||
from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
|
||||
from modeling_glide import GLIDE
|
||||
|
||||
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
|
||||
state_dict = torch.load("base.pt", map_location="cpu")
|
||||
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
|
||||
|
||||
### Convert the text encoder
|
||||
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=50257,
|
||||
max_position_embeddings=128,
|
||||
hidden_size=512,
|
||||
intermediate_size=2048,
|
||||
num_hidden_layers=16,
|
||||
num_attention_heads=8,
|
||||
max_position_embeddings=128,
|
||||
use_padding_embeddings=True,
|
||||
)
|
||||
model = CLIPTextModel(config).eval()
|
||||
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
|
||||
tokenizer.save_pretrained("./glide-base")
|
||||
#tokenizer.save_pretrained("./glide-base")
|
||||
|
||||
hf_encoder = model.text_model
|
||||
|
||||
@@ -48,8 +51,33 @@ for layer_idx in range(config.num_hidden_layers):
|
||||
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
|
||||
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
|
||||
|
||||
inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
#with torch.no_grad():
|
||||
# outputs = model(**inputs)
|
||||
|
||||
model.save_pretrained("./glide-base")
|
||||
#model.save_pretrained("./glide-base")
|
||||
|
||||
### Convert the UNet
|
||||
|
||||
unet_model = UNetGLIDEModel(
|
||||
in_channels=3,
|
||||
model_channels=192,
|
||||
out_channels=6,
|
||||
num_res_blocks=3,
|
||||
attention_resolutions=(2, 4, 8),
|
||||
dropout=0.1,
|
||||
channel_mult=(1, 2, 3, 4),
|
||||
num_heads=1,
|
||||
num_head_channels=64,
|
||||
num_heads_upsample=1,
|
||||
use_scale_shift_norm=True,
|
||||
resblock_updown=True,
|
||||
)
|
||||
|
||||
unet_model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
|
||||
|
||||
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
|
||||
|
||||
glide.save_pretrained("./glide-base")
|
||||
@@ -14,46 +14,136 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import UNetGLIDEModel
|
||||
from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
import tqdm
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|
||||
|
||||
|
||||
class GLIDE(DiffusionPipeline):
|
||||
def __init__(self, unet: UNetGLIDEModel, noise_scheduler):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNetGLIDEModel,
|
||||
noise_scheduler: ClassifierFreeGuidanceScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: GPT2Tokenizer
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
|
||||
|
||||
def __call__(self, generator=None, torch_device=None):
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior:
|
||||
|
||||
q(x_{t-1} | x_t, x_0)
|
||||
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
_extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(
|
||||
self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
|
||||
)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
|
||||
:param model: the model, which takes a signal and a batch of timesteps
|
||||
as input.
|
||||
:param x: the [N x C x ...] tensor at time t.
|
||||
:param t: a 1-D Tensor of timesteps.
|
||||
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict with the following keys:
|
||||
- 'mean': the model mean output.
|
||||
- 'variance': the model variance output.
|
||||
- 'log_variance': the log of 'variance'.
|
||||
- 'pred_xstart': the prediction for x_0.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
model_output = model(x, t, transformer_out)
|
||||
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = torch.split(model_output, C, dim=1)
|
||||
min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
model_variance = torch.exp(model_log_variance)
|
||||
|
||||
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
||||
if clip_denoised:
|
||||
pred_xstart = pred_xstart.clamp(-1, 1)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
||||
|
||||
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
return model_mean, model_variance, model_log_variance, pred_xstart
|
||||
|
||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
|
||||
return (
|
||||
_extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
def __call__(self, prompt, generator=None, torch_device=None):
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
|
||||
# 1. Sample gaussian noise
|
||||
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
|
||||
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
|
||||
# i) define coefficients for time step t
|
||||
clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
|
||||
clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
|
||||
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
|
||||
clip_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
|
||||
image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
|
||||
|
||||
# ii) predict noise residual
|
||||
with torch.no_grad():
|
||||
noise_residual = self.unet(image, t)
|
||||
# 2. Encode tokens
|
||||
# an empty input is needed to guide the model away from (
|
||||
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
transformer_out = self.text_encoder(**inputs).last_hidden_state
|
||||
|
||||
# iii) compute predicted image from residual
|
||||
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
|
||||
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
|
||||
pred_mean = torch.clamp(pred_mean, -1, 1)
|
||||
prev_image = clip_coeff * pred_mean + image_coeff * image
|
||||
|
||||
# iv) sample variance
|
||||
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
|
||||
|
||||
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
|
||||
sampled_prev_image = prev_image + prev_variance
|
||||
image = sampled_prev_image
|
||||
num_timesteps = len(self.noise_scheduler)
|
||||
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
|
||||
t = torch.tensor([i] * image.shape[0], device=torch_device)
|
||||
mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
|
||||
noise = self.noise_scheduler.sample_noise(image.shape)
|
||||
nonzero_mask = (
|
||||
(t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
|
||||
) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
|
||||
return image
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import torch
|
||||
from .modeling_glide import GLIDE
|
||||
from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
|
||||
from modeling_glide import GLIDE
|
||||
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(0)
|
||||
|
||||
# 1. Load models
|
||||
|
||||
scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
|
||||
model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
|
||||
|
||||
pipeline = GLIDE(model, scheduler)
|
||||
pipeline = GLIDE.from_pretrained("fusing/glide-base")
|
||||
|
||||
img = pipeline(generator)
|
||||
|
||||
|
||||
@@ -7,5 +7,7 @@ __version__ = "0.0.1"
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_glide import UNetGLIDEModel
|
||||
from .models.clip_text_transformer import CLIPTextModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
|
||||
@@ -18,3 +18,4 @@
|
||||
|
||||
from .unet import UNetModel
|
||||
from .unet_glide import UNetGLIDEModel
|
||||
from .clip_text_transformer import CLIPTextModel
|
||||
|
||||
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
#self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
@@ -653,13 +653,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
||||
transformer_proj = self.transformer_proj(transformer_out[:, -1])
|
||||
transformer_out = transformer_out.permute(0, 2, 1) # NLC -> NCL
|
||||
|
||||
emb = emb + transformer_proj.to(emb)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
h = module(h, emb, transformer_out)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb)
|
||||
h = self.middle_block(h, emb, transformer_out)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
h = module(h, emb, transformer_out)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
@@ -35,10 +35,12 @@ logger = logging.get_logger(__name__)
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers
|
||||
"GaussianDDPMScheduler": ["save_config", "from_config"],
|
||||
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
|
||||
},
|
||||
"transformers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -17,3 +17,4 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .gaussian_ddpm import GaussianDDPMScheduler
|
||||
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
|
||||
102
src/diffusers/schedulers/classifier_free_guidance.py
Normal file
102
src/diffusers/schedulers/classifier_free_guidance.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, beta_start, beta_end):
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float64)
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
||||
|
||||
config_name = SAMPLING_CONFIG_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timesteps=1000,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
timesteps=timesteps,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
self.num_timesteps = int(timesteps)
|
||||
|
||||
if beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = (
|
||||
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = np.log(
|
||||
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
||||
)
|
||||
self.posterior_mean_coef1 = (
|
||||
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
self.posterior_mean_coef2 = (
|
||||
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
|
||||
def sample_noise(self, shape, device, generator=None):
|
||||
# always sample on CPU to be deterministic
|
||||
return torch.randn(shape, generator=generator).to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_timesteps
|
||||
Reference in New Issue
Block a user