1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

end-to-end glide pipeline with DDIM scheduler for upscaling

This commit is contained in:
anton-l
2022-06-09 10:53:53 +02:00
parent ff89f80869
commit dc6324d44b
8 changed files with 238 additions and 87 deletions

View File

@@ -1,7 +1,7 @@
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -76,7 +76,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
state_dict = torch.load("upsample.pt", map_location="cpu")
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
in_channels=6,
@@ -93,12 +93,12 @@ superres_model = GLIDESuperResUNetModel(
resblock_updown=True,
)
superres_model.load_state_dict(state_dict)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
upscale_unet=superres_model, upscale_noise_scheduler=scheduler)
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide.save_pretrained("./glide-base")

View File

@@ -18,7 +18,7 @@ import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
def __init__(
self,
unet: GLIDETextToImageUNetModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_unet: GLIDETextToImageUNetModel,
text_noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler
):
super().__init__()
self.register_modules(
unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
)
def q_posterior_mean_variance(self, x_start, x_t, t):
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
_extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, transformer_out)
if transformer_out is None:
# super-res model
model_output = model(x, t, low_res)
else:
# text2image model
model_output = model(x, t, transformer_out)
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = torch.exp(model_log_variance)
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
if clip_denoised:
pred_xstart = pred_xstart.clamp(-1, 1)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return model_mean, model_variance, model_log_variance, pred_xstart
def _predict_xstart_from_eps(self, x_t, t, eps):
def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
return (
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
@torch.no_grad()
def __call__(self, prompt, generator=None, torch_device=None):
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.text_unet.to(torch_device)
self.text_encoder.to(torch_device)
self.upscale_unet.to(torch_device)
# Create a classifier-free guidance sampling function
guidance_scale = 3.0
def model_fn(x_t, ts, transformer_out, **kwargs):
def text_model_fn(x_t, ts, transformer_out, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.unet(combined, ts, transformer_out, **kwargs)
model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
image = self.text_noise_scheduler.sample_noise(
(batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
)
# 2. Encode tokens
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
attention_mask = inputs["attention_mask"].to(torch_device)
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
num_timesteps = len(self.noise_scheduler)
# 3. Run the text2image generation step
num_timesteps = len(self.text_noise_scheduler)
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
t = torch.tensor([i] * image.shape[0], device=torch_device)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
)
noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
# 4. Run the upscaling step
batch_size = 1
image = image[:1]
low_res = ((image + 1) * 127.5).round() / 127.5 - 1
eta = 0.0
# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
) * upsample_temp
num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
generator=generator)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
image = image[0].permute(1, 2, 0)
return image

View File

@@ -9,7 +9,6 @@ matplotlib.rcParams['interactive'] = True
generator = torch.Generator()
generator = generator.manual_seed(0)
# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")
img = pipeline("a pencil sketch of a corgi", generator)

View File

@@ -13,3 +13,4 @@ from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler

View File

@@ -419,11 +419,11 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@@ -438,24 +438,6 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
transformer_dim=None,
):
super().__init__()
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
@@ -632,7 +614,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, y=None):
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
@@ -641,17 +623,10 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
@@ -671,10 +646,66 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512
):
super().__init__(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
)
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
)
self.transformer_proj = nn.Linear(kwargs["transformer_dim"], self.model_channels * 4)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, x, timesteps, transformer_out=None):
hs = []
@@ -705,11 +736,77 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
):
super().__init__(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
def forward(self, x, timesteps, low_res=None, **kwargs):
def forward(self, x, timesteps, low_res=None):
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
return super().forward(x, timesteps, **kwargs)
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
return self.out(h)

View File

@@ -39,6 +39,7 @@ LOADABLE_CLASSES = {
"CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers
"GaussianDDPMScheduler": ["save_config", "from_config"],
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
"GlideDDIMScheduler": ["save_config", "from_config"],
},
"transformers": {
"GPT2Tokenizer": ["save_pretrained", "from_pretrained"],

View File

@@ -18,3 +18,4 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
import numpy as np
from torch import nn
from ..configuration_utils import ConfigMixin
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
SAMPLING_CONFIG_NAME = "scheduler_config.json"
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
class GlideDDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
variance_type="fixed_small",
variance_type="fixed_large"
):
super().__init__()
self.register(
timesteps=timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
variance_type=variance_type,
)
self.num_timesteps = int(timesteps)
if beta_schedule == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / self.num_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -99,4 +93,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.num_timesteps
return self.num_timesteps