diff --git a/README.md b/README.md index 2c608de6ac..e76207044c 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th Example: ```python -from diffusers import UNetModel, GaussianDiffusion +from diffusers import UNetModel, GaussianDDPMScheduler import torch # 1. Load model @@ -40,7 +40,7 @@ time_step = torch.tensor([10]) image = unet(dummy_noise, time_step) # 3. Load sampler -sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") +sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy") # 4. Sample image from sampler passing the model image = sampler.sample(model, batch_size=1) @@ -54,12 +54,12 @@ print(image) Example: ```python -from diffusers import UNetModel, GaussianDiffusion +from diffusers import UNetModel, GaussianDDPMScheduler from modeling_ddpm import DDPM import tempfile unet = UNetModel.from_pretrained("fusing/ddpm_dummy") -sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") +sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy") # compose Diffusion Pipeline ddpm = DDPM(unet, sampler) diff --git a/examples/sample_loop.py b/examples/sample_loop.py index 8255140ce9..d8134a6bf3 100755 --- a/examples/sample_loop.py +++ b/examples/sample_loop.py @@ -1,99 +1,157 @@ #!/usr/bin/env python3 -from diffusers import UNetModel, GaussianDiffusion +from diffusers import UNetModel, GaussianDDPMScheduler import torch import torch.nn.functional as F +import numpy as np +import PIL.Image +import tqdm -unet = UNetModel.from_pretrained("fusing/ddpm_dummy") -diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy") - +#torch_device = "cuda" +# +#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church") +#unet.to(torch_device) +# +#TIME_STEPS = 10 +# +#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS) +# +#diffusion_config = { +# "beta_start": 0.0001, +# "beta_end": 0.02, +# "num_diffusion_timesteps": TIME_STEPS, +#} +# # 2. 
Do one denoising step with model -batch_size, num_channels, height, width = 1, 3, 32, 32 -dummy_noise = torch.ones((batch_size, num_channels, height, width)) - - -TIME_STEPS = 10 - - +#batch_size, num_channels, height, width = 1, 3, 256, 256 +# +#torch.manual_seed(0) +#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda") +# +# # Helper -def extract(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) +#def noise_like(shape, device, repeat=False): +# def repeat_noise(): +# return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) +# +# def noise(): +# return torch.randn(shape, device=device) +# +# return repeat_noise() if repeat else noise() +# +# +#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64) +#betas = torch.tensor(betas, device=torch_device) +#alphas = 1.0 - betas +# +#alphas_cumprod = torch.cumprod(alphas, axis=0) +#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) +# +#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) +#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) +# +#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) +#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20)) +# +# +#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) +#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1) +# +# +#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod) +#coeff = 1 / torch.sqrt(alphas) -def noise_like(shape, device, repeat=False): - def repeat_noise(): - return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) +def real_fn(): + # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf + # 1: x_t ~ N(0,1) + x_t = noise_image + # 2: for t = T, ...., 1 do + for i in reversed(range(TIME_STEPS)): + t = torch.tensor([i]).to(torch_device) + # 3: z ~ N(0, 1) + noise = noise_like(x_t.shape, torch_device) - def noise(): - return torch.randn(shape, device=device) + # 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz + # ------------------------- MODEL ------------------------------------# + with torch.no_grad(): + pred_noise = unet(x_t, t) # pred epsilon_theta - return repeat_noise() if repeat else noise() + # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise + # pred_x.clamp_(-1.0, 1.0) + # pred mean + # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t + # --------------------------------------------------------------------# + + posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise) + + # ------------------------- Variance Scheduler -----------------------# + # pred variance + posterior_log_variance = posterior_log_variance_clipped[t] + + b, *_, device = *x_t.shape, x_t.device + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1))) + posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp() + # --------------------------------------------------------------------# + + x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32) + x_t = x_t_1 + + print(x_t.abs().sum()) -# Schedule -def cosine_beta_schedule(timesteps, s=0.008): - """ - cosine schedule - as proposed in 
https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - x = torch.linspace(0, timesteps, steps, dtype=torch.float64) - alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) +def post_process_to_image(x_t): + image = x_t.cpu().permute(0, 2, 3, 1) + image = (image + 1.0) * 127.5 + image = image.numpy().astype(np.uint8) + + return PIL.Image.fromarray(image[0]) -betas = cosine_beta_schedule(TIME_STEPS) -alphas = 1.0 - betas -alphas_cumprod = torch.cumprod(alphas, axis=0) -alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) +from pytorch_diffusion import Diffusion -posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) -posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) - -posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) -posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20)) +#diffusion = Diffusion.from_pretrained("lsun_church") +#samples = diffusion.denoise(1) +# +#image = post_process_to_image(samples) +#image.save("check.png") +#import ipdb; ipdb.set_trace() -sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) -sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1) +device = "cuda" +scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10) + +import ipdb; ipdb.set_trace() + +model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device) + torch.manual_seed(0) +next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device) -# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf -# 1: x_t ~ N(0,1) -x_t = dummy_noise -# 2: for t = T, ...., 1 do -for i in reversed(range(TIME_STEPS)): - t = torch.tensor([i]) - # 3: z ~ N(0, 1) - noise = noise_like(x_t.shape, "cpu") +for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)): + # define coefficients for time step t + clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) + clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) + image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) + clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) - # 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz - # ------------------------- MODEL ------------------------------------# - pred_noise = unet(x_t, t) # pred epsilon_theta - pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise - pred_x.clamp_(-1.0, 1.0) - # pred mean - posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t - # --------------------------------------------------------------------# + # predict noise residual + with torch.no_grad(): + noise_residual = model(next_image, t) - # ------------------------- Variance Scheduler -----------------------# - # pred variance - posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape) - b, *_, device = *x_t.shape, x_t.device - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1))) - posterior_variance = nonzero_mask * 
(0.5 * posterior_log_variance).exp() - # --------------------------------------------------------------------# + # compute prev image from noise + pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual + pred_mean = torch.clamp(pred_mean, -1, 1) + image = clip_coeff * pred_mean + image_coeff * next_image - x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32) + # sample variance + variance = scheduler.sample_variance(t, image.shape, device=device) - # FOR PATRICK TO VERIFY: make sure manual loop is equal to function - # --------------------------------------------------------------------# - x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise) - assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3 - # --------------------------------------------------------------------# + # sample previous image + sampled_image = image + variance - x_t = x_t_1 + next_image = sampled_image + + +image = post_process_to_image(next_image) +image.save("example_new.png") diff --git a/models/vision/ddpm/example.py b/models/vision/ddpm/example.py index ec339c4cdf..fb1b20477a 100755 --- a/models/vision/ddpm/example.py +++ b/models/vision/ddpm/example.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 -from diffusers import UNetModel, GaussianDiffusion -from modeling_ddpm import DDPM import tempfile +from diffusers import GaussianDDPMScheduler, UNetModel +from modeling_ddpm import DDPM + + unet = UNetModel.from_pretrained("fusing/ddpm_dummy") -sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") +sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy") # compose Diffusion Pipeline ddpm = DDPM(unet, sampler) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 3525ec30c0..ccd29454bd 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline class DDPM(DiffusionPipeline): - def __init__(self, unet, gaussian_sampler): super().__init__(unet=unet, gaussian_sampler=gaussian_sampler) diff --git a/models/vision/ddpm/run_ddpm.py b/models/vision/ddpm/run_ddpm.py index 6bf131f7b5..88de931381 100755 --- a/models/vision/ddpm/run_ddpm.py +++ b/models/vision/ddpm/run_ddpm.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 import torch -from diffusers import GaussianDiffusion, UNetModel +from diffusers import GaussianDDPMScheduler, UNetModel model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8)) -diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2 +diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2 training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1 loss = diffusion(training_images) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 135d49c83e..cfb2da04a7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -4,8 +4,7 @@ __version__ = "0.0.1" -from .models.unet import UNetModel -from .samplers.gaussian import GaussianDiffusion - -from .pipeline_utils import DiffusionPipeline from .modeling_utils import PreTrainedModel +from .models.unet import UNetModel +from .pipeline_utils import DiffusionPipeline +from .schedulers.gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 164156437e..f287922eb0 100644 --- a/src/diffusers/configuration_utils.py +++ 
b/src/diffusers/configuration_utils.py @@ -17,10 +17,10 @@ import copy +import inspect import json import os import re -import inspect from typing import Any, Dict, Tuple, Union from requests import HTTPError @@ -186,6 +186,11 @@ class Config: expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) expected_keys.remove("self") + for key in expected_keys: + if key in kwargs: + # overwrite key + config_dict[key] = kwargs.pop(key) + passed_keys = set(config_dict.keys()) unused_kwargs = kwargs @@ -194,17 +199,16 @@ class Config: if len(expected_keys - passed_keys) > 0: logger.warn( - f"{expected_keys - passed_keys} was not found in config. " - f"Values will be initialized to default values." + f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) return config_dict, unused_kwargs @classmethod - def from_config( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs - ): - config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + config_dict, unused_kwargs = cls.get_config_dict( + pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + ) model = cls(**config_dict) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index d346fb7400..8c00ceb75c 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -24,6 +24,7 @@ from requests import HTTPError # CHANGE to diffusers.utils from transformers.utils import ( + CONFIG_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, EntryNotFoundError, RepositoryNotFoundError, @@ -33,7 +34,6 @@ from transformers.utils import ( is_offline_mode, is_remote_url, logging, - CONFIG_NAME, ) diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 6b8069d1cb..97ffef648e 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -17,35 +17,362 @@ import copy import math -from functools import partial -from inspect import isfunction from pathlib import Path import torch -from torch import einsum, nn +from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.optim import Adam from torch.utils import data -from einops import rearrange -from torchvision import utils, transforms +from torchvision import transforms, utils +from PIL import Image from tqdm import tqdm from ..configuration_utils import Config from ..modeling_utils import PreTrainedModel -from PIL import Image - -# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -def exists(x): - return x is not None +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, 
kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class UNetModel(PreTrainedModel, Config): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 1, 2, 2, 4, 4), + num_res_blocks=2, + attn_resolutions=(16,), + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=256, + ): + super().__init__() + self.register( + ch=ch, + out_ch=out_ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=in_channels, + resolution=resolution, + ) + ch_mult = tuple(ch_mult) + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + 
self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t): + assert x.shape[2] == x.shape[3] == self.resolution + + if not torch.is_tensor(t): + t = torch.tensor([t], dtype=torch.long, device=x.device) + + # timestep embedding + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +# dataset classes + +class Dataset(data.Dataset): + def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): + super().__init__() + self.folder = folder + self.image_size = image_size + self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + + self.transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.RandomHorizontalFlip(), + transforms.CenterCrop(image_size), + transforms.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + img = Image.open(path) + return self.transform(img) + + +# trainer class +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new def cycle(dl): @@ -63,345 +390,7 @@ def num_to_groups(num, divisor): return arr -def normalize_to_neg_one_to_one(img): - return img * 2 - 1 - - -def unnormalize_to_zero_to_one(t): - return (t + 1) * 0.5 - - -# small helper modules - - -class EMA: - def __init__(self, beta): - super().__init__() - self.beta = beta - - def update_model_average(self, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): - old_weight, up_weight = ma_params.data, current_params.data - ma_params.data = self.update_average(old_weight, up_weight) - - def update_average(self, old, new): - if old is None: - return new - return old * self.beta + (1 - self.beta) * new - - -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x, *args, **kwargs): - return self.fn(x, *args, **kwargs) + x - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = 
dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -def Upsample(dim): - return nn.ConvTranspose2d(dim, dim, 4, 2, 1) - - -def Downsample(dim): - return nn.Conv2d(dim, dim, 4, 2, 1) - - -class LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.eps = eps - self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) - - def forward(self, x): - var = torch.var(x, dim=1, unbiased=False, keepdim=True) - mean = torch.mean(x, dim=1, keepdim=True) - return (x - mean) / (var + self.eps).sqrt() * self.g + self.b - - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.fn = fn - self.norm = LayerNorm(dim) - - def forward(self, x): - x = self.norm(x) - return self.fn(x) - - -# building block modules - - -class Block(nn.Module): - def __init__(self, dim, dim_out, groups=8): - super().__init__() - self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) - self.norm = nn.GroupNorm(groups, dim_out) - self.act = nn.SiLU() - - def forward(self, x, scale_shift=None): - x = self.proj(x) - x = self.norm(x) - - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - - x = self.act(x) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): - super().__init__() - self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None - - self.block1 = Block(dim, dim_out, groups=groups) - self.block2 = Block(dim_out, dim_out, groups=groups) - self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - - def forward(self, x, time_emb=None): - - scale_shift = None - if exists(self.mlp) and exists(time_emb): - time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, "b c -> b c 1 1") - scale_shift = time_emb.chunk(2, dim=1) - - h = self.block1(x, scale_shift=scale_shift) - - h = self.block2(h) - return h + self.res_conv(x) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.scale = dim_head**-0.5 - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - - self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim=1) - q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) - - q = q.softmax(dim=-2) - k = k.softmax(dim=-1) - - q = q * self.scale - context = torch.einsum("b h d n, b h e n -> b h d e", k, v) - - out = torch.einsum("b h d e, b h d n -> b h e n", context, q) - out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) - return self.to_out(out) - - -class Attention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.scale = dim_head**-0.5 - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim=1) - q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) - q = q * self.scale - - sim = einsum("b h d i, b h d j -> b 
h i j", q, k) - sim = sim - sim.amax(dim=-1, keepdim=True).detach() - attn = sim.softmax(dim=-1) - - out = einsum("b h i j, b h d j -> b h i d", attn, v) - out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) - return self.to_out(out) - - -class UNetModel(PreTrainedModel, Config): - - def __init__( - self, - dim=64, - dim_mults=(1, 2, 4, 8), - init_dim=None, - out_dim=None, - channels=3, - with_time_emb=True, - resnet_block_groups=8, - learned_variance=False, - ): - super().__init__() - self.register( - dim=dim, - dim_mults=dim_mults, - init_dim=init_dim, - out_dim=out_dim, - channels=channels, - with_time_emb=with_time_emb, - resnet_block_groups=resnet_block_groups, - learned_variance=learned_variance, - ) - init_dim = None - out_dim = None - channels = 3 - with_time_emb = True - resnet_block_groups = 8 - learned_variance = False - - # determine dimensions - - dim_mults = dim_mults - dim = dim - self.channels = channels - - init_dim = default(init_dim, dim // 3 * 2) - self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - block_klass = partial(ResnetBlock, groups=resnet_block_groups) - - # time embeddings - - if with_time_emb: - time_dim = dim * 4 - self.time_mlp = nn.Sequential( - SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) - ) - else: - time_dim = None - self.time_mlp = None - - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - self.downs.append( - nn.ModuleList( - [ - block_klass(dim_in, dim_out, time_emb_dim=time_dim), - block_klass(dim_out, dim_out, time_emb_dim=time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Downsample(dim_out) if not is_last else nn.Identity(), - ] - ) - ) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (num_resolutions - 1) - - self.ups.append( - nn.ModuleList( - [ - block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), - block_klass(dim_in, dim_in, time_emb_dim=time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Upsample(dim_in) if not is_last else nn.Identity(), - ] - ) - ) - - default_out_dim = channels * (1 if not learned_variance else 2) - self.out_dim = default(out_dim, default_out_dim) - - self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1)) - - def forward(self, x, time): - x = self.init_conv(x) - - t = self.time_mlp(time) if exists(self.time_mlp) else None - - h = [] - - for block1, block2, attn, downsample in self.downs: - x = block1(x, t) - x = block2(x, t) - x = attn(x) - h.append(x) - x = downsample(x) - - x = self.mid_block1(x, t) - x = self.mid_attn(x) - x = self.mid_block2(x, t) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim=1) - x = block1(x, t) - x = block2(x, t) - x = attn(x) - x = upsample(x) - - return self.final_conv(x) - - -# dataset classes - - -class Dataset(data.Dataset): - def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]): - super().__init__() - self.folder = folder - self.image_size = image_size - self.paths = [p for ext in exts 
for p in Path(f"{folder}").glob(f"**/*.{ext}")] - - self.transform = transforms.Compose( - [ - transforms.Resize(image_size), - transforms.RandomHorizontalFlip(), - transforms.CenterCrop(image_size), - transforms.ToTensor(), - ] - ) - - def __len__(self): - return len(self.paths) - - def __getitem__(self, index): - path = self.paths[index] - img = Image.open(path) - return self.transform(img) - - -# trainer class - - class Trainer(object): - def __init__( self, diffusion_model, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 2e4c88b785..5740f5f0ca 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os from typing import Optional, Union -import importlib - -from .configuration_utils import Config # CHANGE to diffusers.utils from transformers.utils import logging +from .configuration_utils import Config + INDEX_FILE = "diffusion_model.pt" @@ -33,7 +33,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "GaussianDiffusion": ["save_config", "from_config"], + "GaussianDDPMScheduler": ["save_config", "from_config"], }, "transformers": { "PreTrainedModel": ["save_pretrained", "from_pretrained"], diff --git a/src/diffusers/samplers/gaussian.py b/src/diffusers/samplers/gaussian.py deleted file mode 100644 index f1fb9eeb49..0000000000 --- a/src/diffusers/samplers/gaussian.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import torch -import torch.nn.functional as F -from torch import nn -from inspect import isfunction -from tqdm import tqdm - -from ..configuration_utils import Config -SAMPLING_CONFIG_NAME = "sampler_config.json" - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def cycle(dl): - while True: - for data_dl in dl: - yield data_dl - - -def num_to_groups(num, divisor): - groups = num // divisor - remainder = num % divisor - arr = [divisor] * groups - if remainder > 0: - arr.append(remainder) - return arr - - -def normalize_to_neg_one_to_one(img): - return img * 2 - 1 - - -def unnormalize_to_zero_to_one(t): - return (t + 1) * 0.5 - - -# small helper modules - - -class EMA: - def __init__(self, beta): - super().__init__() - self.beta = beta - - def update_model_average(self, ma_model, current_model): - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): - old_weight, up_weight = ma_params.data, current_params.data - ma_params.data = self.update_average(old_weight, up_weight) - - def update_average(self, old, new): - if old is None: - return new - return old * self.beta + (1 - self.beta) * new - - -# gaussian diffusion trainer class - - -def extract(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def noise_like(shape, device, repeat=False): - def repeat_noise(): - return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - - def noise(): - return torch.randn(shape, device=device) - - return repeat_noise() if repeat else noise() - - -def linear_beta_schedule(timesteps): - scale = 1000 / timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) - - -def cosine_beta_schedule(timesteps, s=0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - x = torch.linspace(0, timesteps, steps, dtype=torch.float64) - alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - - -class GaussianDiffusion(nn.Module, Config): - - config_name = SAMPLING_CONFIG_NAME - - def __init__( - self, - image_size, - channels=3, - timesteps=1000, - loss_type="l1", - objective="pred_noise", - beta_schedule="cosine", - ): - super().__init__() - self.register( - image_size=image_size, - channels=channels, - timesteps=timesteps, - loss_type=loss_type, - objective=objective, - beta_schedule=beta_schedule, - ) - - self.channels = channels - self.image_size = image_size - self.objective = objective - - if beta_schedule == "linear": - betas = linear_beta_schedule(timesteps) - elif beta_schedule == "cosine": - betas = cosine_beta_schedule(timesteps) - else: - raise ValueError(f"unknown beta schedule {beta_schedule}") - - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, axis=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - # helper function to register buffer from float64 to float32 - - def register_buffer(name, val): - self.register_buffer(name, val.to(torch.float32)) - - register_buffer("betas", betas) - register_buffer("alphas_cumprod", 
alphas_cumprod) - register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) - register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) - register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) - register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) - register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - register_buffer("posterior_variance", posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20))) - register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) - register_buffer( - "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) - ) - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, model, x, t, clip_denoised: bool): - model_output = model(x, t) - - if self.objective == "pred_noise": - x_start = self.predict_start_from_noise(x, t=t, noise=model_output) - elif self.objective == "pred_x0": - x_start = model_output - else: - raise ValueError(f"unknown objective {self.objective}") - - if clip_denoised: - x_start.clamp_(-1.0, 1.0) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised) - if noise is None: - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - return result - - @torch.no_grad() - def p_sample_loop(self, model, shape): - device = self.betas.device - - b = shape[0] - img = torch.randn(shape, device=device) - - for i in tqdm( - reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps - ): - img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long)) - - img = unnormalize_to_zero_to_one(img) - return img - - @torch.no_grad() - def sample(self, model, batch_size=16): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop(model, 
(batch_size, channels, image_size, image_size)) - - @torch.no_grad() - def interpolate(self, model, x1, x2, t=None, lam=0.5): - b, *_, device = *x1.shape, x1.device - t = default(t, self.num_timesteps - 1) - - assert x1.shape == x2.shape - - t_batched = torch.stack([torch.tensor(t, device=device)] * b) - xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) - - img = (1 - lam) * xt1 + lam * xt2 - for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t): - img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long)) - - return img - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - @property - def loss_fn(self): - if self.loss_type == "l1": - return F.l1_loss - elif self.loss_type == "l2": - return F.mse_loss - else: - raise ValueError(f"invalid loss type {self.loss_type}") - - def p_losses(self, model, x_start, t, noise=None): - b, c, h, w = x_start.shape - noise = default(noise, lambda: torch.randn_like(x_start)) - - x = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = model(x, t) - - if self.objective == "pred_noise": - target = noise - elif self.objective == "pred_x0": - target = x_start - else: - raise ValueError(f"unknown objective {self.objective}") - - loss = self.loss_fn(model_out, target) - return loss - - def forward(self, model, img, *args, **kwargs): - b, _, h, w, device, img_size, = ( - *img.shape, - img.device, - self.image_size, - ) - assert h == img_size and w == img_size, f"height and width of image must be {img_size}" - t = torch.randint(0, self.num_timesteps, (b,), device=device).long() - - img = normalize_to_neg_one_to_one(img) - return self.p_losses(model, img, t, *args, **kwargs) diff --git a/src/diffusers/samplers/__init__.py b/src/diffusers/schedulers/__init__.py similarity index 94% rename from src/diffusers/samplers/__init__.py rename to src/diffusers/schedulers/__init__.py index 76aa8aab0c..81d9601a1b 100644 --- a/src/diffusers/samplers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .gaussian import GaussianDiffusion +from .gaussian_ddpm import GaussianDDPMScheduler diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py new file mode 100644 index 0000000000..b0970b458d --- /dev/null +++ b/src/diffusers/schedulers/gaussian_ddpm.py @@ -0,0 +1,98 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +from torch import nn + +from ..configuration_utils import Config + + +SAMPLING_CONFIG_NAME = "scheduler_config.json" + + +def linear_beta_schedule(timesteps, beta_start, beta_end): + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + + +class GaussianDDPMScheduler(nn.Module, Config): + + config_name = SAMPLING_CONFIG_NAME + + def __init__( + self, + timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + variance_type="fixed_small", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_start=beta_start, + beta_end=beta_end, + beta_schedule=beta_schedule, + variance_type=variance_type, + ) + self.num_timesteps = int(timesteps) + + if beta_schedule == "linear": + betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + + variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + if variance_type == "fixed_small": + log_variance = torch.log(variance.clamp(min=1e-20)) + elif variance_type == "fixed_large": + log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) + + self.register_buffer("betas", betas.to(torch.float32)) + self.register_buffer("alphas", alphas.to(torch.float32)) + self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32)) + + self.register_buffer("log_variance", log_variance.to(torch.float32)) + + def get_alpha(self, time_step): + return self.alphas[time_step] + + def get_beta(self, time_step): + return self.betas[time_step] + + def get_alpha_prod(self, time_step): + if time_step < 0: + return torch.tensor(1.0) + return self.alphas_cumprod[time_step] + + def sample_variance(self, time_step, shape, device, generator=None): + variance = self.log_variance[time_step] + nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1) + + noise = self.sample_noise(shape, device=device, generator=generator) + + sampled_variance = nonzero_mask * (0.5 * variance).exp() + sampled_variance = sampled_variance * noise + + return sampled_variance + + def sample_noise(self, shape, device, generator=None): + # always sample on CPU to be deterministic + return torch.randn(shape, generator=generator).to(device) + + def __len__(self): + return self.num_timesteps diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 1233980f39..4655c96749 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -16,13 +16,45 @@ import random import tempfile import unittest +import os +from distutils.util import strtobool import torch -from diffusers import GaussianDiffusion, UNetModel +from diffusers import GaussianDDPMScheduler, UNetModel global_rng = random.Random() +torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. 
+ raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) def floats_tensor(shape, scale=1.0, rng=None, name=None): @@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase): return (noise, time_step) def test_from_pretrained_save_pretrained(self): - model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2) + model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase): class SamplerTesterMixin(unittest.TestCase): - @property - def dummy_model(self): - return UNetModel.from_pretrained("fusing/ddpm_dummy") + @slow + def test_sample(self): + generator = torch.Generator() + generator = generator.manual_seed(6694729458485568) - def test_from_pretrained_save_pretrained(self): - sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1") + # 1. Load models + scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") + model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) - with tempfile.TemporaryDirectory() as tmpdirname: - sampler.save_config(tmpdirname) - new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False) + # 2. Sample gaussian noise + image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) - model = self.dummy_model + # 3. 
Denoise + for t in reversed(range(len(scheduler))): + # i) define coefficients for time step t + clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) + clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) + image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) + clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) + # ii) predict noise residual + with torch.no_grad(): + noise_residual = model(image, t) + + # iii) compute predicted image from residual + # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison + pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual + pred_mean = torch.clamp(pred_mean, -1, 1) + prev_image = clip_coeff * pred_mean + image_coeff * image + + # iv) sample variance + prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) + + # v) sample x_{t-1} ~ N(prev_image, prev_variance) + sampled_prev_image = prev_image + prev_variance + image = sampled_prev_image + + # Note: The better test is to simply check with the following lines of code that the image is sensible + # import PIL + # import numpy as np + # image_processed = image.cpu().permute(0, 2, 3, 1) + # image_processed = (image_processed + 1.0) * 127.5 + # image_processed = image_processed.numpy().astype(np.uint8) + # image_pil = PIL.Image.fromarray(image_processed[0]) + # image_pil.save("test.png") + + assert image.shape == (1, 3, 256, 256) + image_slice = image[0, -1, -3:, -3:].cpu() + assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3 + + def test_sample_fast(self): + # 1. Load models + generator = torch.Generator() + generator = generator.manual_seed(6694729458485568) + + scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10) + model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) + + # 2. Sample gaussian noise torch.manual_seed(0) - sampled_out = sampler.sample(model, batch_size=1) - torch.manual_seed(0) - sampled_out_new = new_sampler.sample(model, batch_size=1) + image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) - assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output" + # 3. 
Denoise + for t in reversed(range(len(scheduler))): + # i) define coefficients for time step t + clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) + clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) + image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) + clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) - def test_from_pretrained_hub(self): - sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") - model = self.dummy_model + # ii) predict noise residual + with torch.no_grad(): + noise_residual = model(image, t) - sampled_out = sampler.sample(model, batch_size=1) + # iii) compute predicted image from residual + # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison + pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual + pred_mean = torch.clamp(pred_mean, -1, 1) + prev_image = clip_coeff * pred_mean + image_coeff * image - assert sampled_out is not None, "Make sure output is not None" + # iv) sample variance + prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) + + # v) sample x_{t-1} ~ N(prev_image, prev_variance) + sampled_prev_image = prev_image + prev_variance + image = sampled_prev_image + + assert image.shape == (1, 3, 256, 256) + image_slice = image[0, -1, -3:, -3:].cpu() + assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
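
For reference, the denoising loop exercised above in `tests/test_modeling_utils.py` and `examples/sample_loop.py` can be condensed into one standalone sketch against the new `GaussianDDPMScheduler` API. This is only an illustrative summary of the flow this diff introduces, not part of the change itself: the `fusing/ddpm-lsun-church` checkpoint id, the coefficient names, and the post-processing are taken from the test code, while the output filename and the `timesteps=10` override are arbitrary choices for a quick run.

```python
# Illustrative sketch only (not part of this diff): it restates the sampling
# loop from SamplerTesterMixin.test_sample using the new scheduler API.
# Assumes the "fusing/ddpm-lsun-church" weights referenced in the tests are
# available for this version of the code.
import numpy as np
import PIL.Image
import torch

from diffusers import GaussianDDPMScheduler, UNetModel

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# timesteps=10 exercises the kwarg override added to Config.from_config
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

# x_T ~ N(0, 1); sample_noise draws on CPU first so results are deterministic
image = scheduler.sample_noise(
    (1, model.in_channels, model.resolution, model.resolution), device=torch_device
)

for t in reversed(range(len(scheduler))):
    # coefficients of the DDPM posterior q(x_{t-1} | x_t, x_0)
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (
        (1 - scheduler.get_alpha_prod(t - 1))
        * torch.sqrt(scheduler.get_alpha(t))
        / (1 - scheduler.get_alpha_prod(t))
    )
    clip_coeff = (
        torch.sqrt(scheduler.get_alpha_prod(t - 1))
        * scheduler.get_beta(t)
        / (1 - scheduler.get_alpha_prod(t))
    )

    # predict the noise residual epsilon_theta(x_t, t)
    with torch.no_grad():
        noise_residual = model(image, t)

    # reconstruct x_0, clip it, and form the posterior mean
    pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    prev_image = clip_coeff * pred_mean + image_coeff * image

    # add the sampled variance (zero at t == 0) to get x_{t-1}
    variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device)
    image = prev_image + variance

# map from [-1, 1] back to uint8 pixels, as in post_process_to_image above
image = (image.cpu().permute(0, 2, 3, 1) + 1.0) * 127.5
PIL.Image.fromarray(image.numpy().astype(np.uint8)[0]).save("ddpm_sample.png")
```

The `get_alpha_prod(t - 1)` call is what keeps the `t == 0` step well defined: the scheduler returns `torch.tensor(1.0)` for negative indices, and `sample_variance` masks the noise term out at the final step.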