mirror of https://github.com/huggingface/diffusers.git

improve
@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from the

Example:

```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
import torch

# 1. Load model

@@ -40,7 +40,7 @@ time_step = torch.tensor([10])

image = unet(dummy_noise, time_step)

# 3. Load sampler
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# 4. Sample image from sampler passing the model
image = sampler.sample(model, batch_size=1)

@@ -54,12 +54,12 @@ print(image)

Example:

```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
from modeling_ddpm import DDPM
import tempfile

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
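The save/load round trip promised above can be sketched end to end. This is an illustrative composition of the calls shown in the two examples, assuming the `fusing/ddpm_dummy` checkpoint and the renamed `GaussianDDPMScheduler` API (the same `save_pretrained`/`from_pretrained` and `save_config`/`from_config` pairs that `LOADABLE_CLASSES` below registers):

```python
import tempfile

from diffusers import GaussianDDPMScheduler, UNetModel

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# Save both to one directory, then load them back.
with tempfile.TemporaryDirectory() as tmpdir:
    unet.save_pretrained(tmpdir)
    scheduler.save_config(tmpdir)

    unet_reloaded = UNetModel.from_pretrained(tmpdir)
    scheduler_reloaded = GaussianDDPMScheduler.from_config(tmpdir)
```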
@@ -1,99 +1,157 @@
#!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image
import tqdm

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")

#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
#    "beta_start": 0.0001,
#    "beta_end": 0.02,
#    "num_diffusion_timesteps": TIME_STEPS,
#}
#
# 2. Do one denoising step with model
batch_size, num_channels, height, width = 1, 3, 32, 32
dummy_noise = torch.ones((batch_size, num_channels, height, width))


TIME_STEPS = 10


#batch_size, num_channels, height, width = 1, 3, 256, 256
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
#
#
# Helper
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
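
# Quick illustrative check of `extract` (a sketch): it gathers one schedule
# value per batch element at that element's timestep and reshapes the result
# so it broadcasts over image dimensions.
_coeffs = torch.linspace(0.1, 1.0, 10)  # one value per timestep
_t = torch.tensor([0, 9])  # timesteps for a batch of two
_out = extract(_coeffs, _t, (2, 3, 32, 32))
assert _out.shape == (2, 1, 1, 1)
assert torch.allclose(_out.flatten(), _coeffs[_t])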
#def noise_like(shape, device, repeat=False):
#    def repeat_noise():
#        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
#    def noise():
#        return torch.randn(shape, device=device)
#
#    return repeat_noise() if repeat else noise()
#
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def real_fn():
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_T ~ N(0, I)
    x_t = noise_image
    # 2: for t = T, ..., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)
        # 3: z ~ N(0, I)
        noise = noise_like(x_t.shape, torch_device)

        # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)  # pred epsilon_theta

        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#

        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)

        # ------------------------- Variance Scheduler -----------------------#
        # pred variance
        posterior_log_variance = posterior_log_variance_clipped[t]

        b, *_, device = *x_t.shape, x_t.device
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#

        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1

        print(x_t.abs().sum())


# Schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
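
# Illustrative sanity check of the cosine schedule (a sketch): betas stay
# within [0, 0.999] and the implied cumulative alphas decay monotonically.
_betas = cosine_beta_schedule(1000)
assert _betas.shape == (1000,)
assert _betas.min() >= 0
assert _betas.max() <= 0.999
_alphas_cumprod = torch.cumprod(1.0 - _betas, dim=0)
assert bool((_alphas_cumprod[1:] <= _alphas_cumprod[:-1]).all())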
def post_process_to_image(x_t):
    image = x_t.cpu().permute(0, 2, 3, 1)
    image = (image + 1.0) * 127.5
    image = image.numpy().astype(np.uint8)

    return PIL.Image.fromarray(image[0])


betas = cosine_beta_schedule(TIME_STEPS)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))

sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)


from pytorch_diffusion import Diffusion

#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()


device = "cuda"
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)

import ipdb; ipdb.set_trace()

model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)


torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)


# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_T ~ N(0, I)
x_t = dummy_noise
# 2: for t = T, ..., 1 do
for i in reversed(range(TIME_STEPS)):
    t = torch.tensor([i])
    # 3: z ~ N(0, I)
    noise = noise_like(x_t.shape, "cpu")

    # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
    # ------------------------- MODEL ------------------------------------#
    pred_noise = unet(x_t, t)  # pred epsilon_theta
    pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
    pred_x.clamp_(-1.0, 1.0)
    # pred mean
    posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
    # --------------------------------------------------------------------#

    # ------------------------- Variance Scheduler -----------------------#
    # pred variance
    posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
    b, *_, device = *x_t.shape, x_t.device
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
    posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
    # --------------------------------------------------------------------#

    x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)

    # FOR PATRICK TO VERIFY: make sure manual loop is equal to function
    # --------------------------------------------------------------------#
    x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
    assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
    # --------------------------------------------------------------------#

    x_t = x_t_1


for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # define coefficients for time step t
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # predict noise residual
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # compute prev image from noise
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # sample variance
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # sample previous image
    sampled_image = image + variance
    next_image = sampled_image


image = post_process_to_image(next_image)
image.save("example_new.png")
@@ -1,10 +1,12 @@
#!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
-from modeling_ddpm import DDPM
-import tempfile
+from diffusers import GaussianDDPMScheduler, UNetModel
+from modeling_ddpm import DDPM

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline


class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
@@ -1,12 +1,12 @@
#!/usr/bin/env python3
import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

-diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
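For context, the scalar loss returned above is what a training loop would backpropagate. A minimal sketch follows; the optimizer choice and learning rate are illustrative assumptions, not part of this commit:

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

for _ in range(4):  # a few illustrative steps
    batch = torch.randn(8, 3, 128, 128)  # stand-in for real, normalized images
    loss = diffusion(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```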
@@ -4,8 +4,7 @@

__version__ = "0.0.1"

-from .models.unet import UNetModel
-from .samplers.gaussian import GaussianDiffusion
-
-from .pipeline_utils import DiffusionPipeline
from .modeling_utils import PreTrainedModel
+from .models.unet import UNetModel
+from .pipeline_utils import DiffusionPipeline
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -17,10 +17,10 @@

import copy
+import inspect
import json
import os
import re
-import inspect
from typing import Any, Dict, Tuple, Union

from requests import HTTPError

@@ -186,6 +186,11 @@ class Config:
        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
        expected_keys.remove("self")

+        for key in expected_keys:
+            if key in kwargs:
+                # overwrite key
+                config_dict[key] = kwargs.pop(key)
+
        passed_keys = set(config_dict.keys())

        unused_kwargs = kwargs

@@ -194,17 +199,16 @@

        if len(expected_keys - passed_keys) > 0:
            logger.warn(
-                f"{expected_keys - passed_keys} was not found in config. "
-                f"Values will be initialized to default values."
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        return config_dict, unused_kwargs

    @classmethod
-    def from_config(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
-    ):
-        config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+        config_dict, unused_kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )

        model = cls(**config_dict)
@@ -24,6 +24,7 @@ from requests import HTTPError

# CHANGE to diffusers.utils
from transformers.utils import (
+    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,

@@ -33,7 +34,6 @@ from transformers.utils import (
    is_offline_mode,
    is_remote_url,
    logging,
-    CONFIG_NAME,
)
@@ -17,35 +17,362 @@

import copy
import math
from functools import partial
from inspect import isfunction
from pathlib import Path

import torch
-from torch import einsum, nn
+from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data

from einops import rearrange
-from torchvision import utils, transforms
+from torchvision import transforms, utils
+from PIL import Image
from tqdm import tqdm

from ..configuration_utils import Config
from ..modeling_utils import PreTrainedModel
-from PIL import Image

# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py


def exists(x):
    return x is not None


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
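
# Illustrative check (a sketch): half of each embedding is sines, half
# cosines, so t = 0 yields zeros in the first half and ones in the second:
#
#     emb = get_timestep_embedding(torch.arange(4), 8)
#     assert emb.shape == (4, 8)
#     assert torch.allclose(emb[0, :4], torch.zeros(4))
#     assert torch.allclose(emb[0, 4:], torch.ones(4))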
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class UNetModel(PreTrainedModel, Config):
    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
    ):
        super().__init__()
        self.register(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=in_channels,
            resolution=resolution,
        )
        ch_mult = tuple(ch_mult)
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList(
            [
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ]
        )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        assert x.shape[2] == x.shape[3] == self.resolution

        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
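
# Illustrative usage sketch (matches the small configuration the tests use):
#
#     model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
#     x = torch.ones(1, 3, 32, 32)  # (batch, in_channels, resolution, resolution)
#     out = model(x, torch.tensor([10]))
#     assert out.shape == (1, 3, 32, 32)  # out_ch defaults to 3; spatial size preserved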
# dataset classes


class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


# trainer class
class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
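
# Illustrative EMA usage sketch, assuming a live model and a deep copy of it
# as the shadow model:
#
#     model = nn.Linear(4, 4)
#     ema_model = copy.deepcopy(model)
#     ema = EMA(beta=0.995)
#     ema.update_model_average(ema_model, model)  # shadow <- beta * shadow + (1 - beta) * live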
def cycle(dl):

@@ -63,345 +390,7 @@ def num_to_groups(num, divisor):
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)


def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)

        h = self.block2(h)
        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class UNetModel(PreTrainedModel, Config):
    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 4, 8),
        init_dim=None,
        out_dim=None,
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        learned_variance=False,
    ):
        super().__init__()
        self.register(
            dim=dim,
            dim_mults=dim_mults,
            init_dim=init_dim,
            out_dim=out_dim,
            channels=channels,
            with_time_emb=with_time_emb,
            resnet_block_groups=resnet_block_groups,
            learned_variance=learned_variance,
        )
        init_dim = None
        out_dim = None
        channels = 3
        with_time_emb = True
        resnet_block_groups = 8
        learned_variance = False

        # determine dimensions

        dim_mults = dim_mults
        dim = dim
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1))

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)


# dataset classes


class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


# trainer class


class Trainer(object):
    def __init__(
        self,
        diffusion_model,
@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import importlib
import os
from typing import Optional, Union
-import importlib

-from .configuration_utils import Config

# CHANGE to diffusers.utils
from transformers.utils import logging

+from .configuration_utils import Config


INDEX_FILE = "diffusion_model.pt"
@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
    "diffusers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-        "GaussianDiffusion": ["save_config", "from_config"],
+        "GaussianDDPMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
@@ -1,313 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from inspect import isfunction
from tqdm import tqdm

from ..configuration_utils import Config

SAMPLING_CONFIG_NAME = "sampler_config.json"


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


# gaussian diffusion trainer class


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )

        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(
            reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
        ):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
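The sampler deleted here implements the standard DDPM quantities. For reference, its `q_sample` draws from the closed-form forward (noising) distribution, in the paper's notation:

q(x_t | x_0) = N(x_t; \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t)\, I),
\qquad
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\quad \epsilon \sim N(0, I),

with the buffers `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` caching the two coefficients.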
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .gaussian import GaussianDiffusion
+from .gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/schedulers/gaussian_ddpm.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))

        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
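A short usage sketch for the new scheduler, exercising only the accessors defined in the file above (the default linear schedule is assumed):

```python
import torch

from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler(timesteps=1000)
assert len(scheduler) == 1000

t = 10
beta_t = scheduler.get_beta(t)              # beta_t of the linear schedule
alpha_t = scheduler.get_alpha(t)            # alpha_t = 1 - beta_t
alpha_prod_t = scheduler.get_alpha_prod(t)  # cumulative product of alphas up to t
assert torch.isclose(alpha_t, 1.0 - beta_t)

# Noise is drawn on CPU for determinism, then moved to the requested device.
noise = scheduler.sample_noise((1, 3, 32, 32), device="cpu", generator=torch.Generator().manual_seed(0))
```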
@@ -16,13 +16,45 @@
import random
import tempfile
import unittest
import os
from distutils.util import strtobool

import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


def floats_tensor(shape, scale=1.0, rng=None, name=None):

@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
        return (noise, time_step)

    def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
@@ -77,30 +109,93 @@

class SamplerTesterMixin(unittest.TestCase):

    @property
    def dummy_model(self):
        return UNetModel.from_pretrained("fusing/ddpm_dummy")

    def test_from_pretrained_save_pretrained(self):
        sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")

        with tempfile.TemporaryDirectory() as tmpdirname:
            sampler.save_config(tmpdirname)
            new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)

        model = self.dummy_model

        torch.manual_seed(0)
        sampled_out = sampler.sample(model, batch_size=1)
        torch.manual_seed(0)
        sampled_out_new = new_sampler.sample(model, batch_size=1)

        assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"

    def test_from_pretrained_hub(self):
        sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
        model = self.dummy_model

        sampled_out = sampler.sample(model, batch_size=1)

        assert sampled_out is not None, "Make sure output is not None"

    @slow
    def test_sample(self):
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        # 1. Load models
        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        # Note: The better test is to simply check with the following lines of code that the image is sensible
        # import PIL
        # import numpy as np
        # image_processed = image.cpu().permute(0, 2, 3, 1)
        # image_processed = (image_processed + 1.0) * 127.5
        # image_processed = image_processed.numpy().astype(np.uint8)
        # image_pil = PIL.Image.fromarray(image_processed[0])
        # image_pil.save("test.png")

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3

    def test_sample_fast(self):
        # 1. Load models
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
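For reference, the per-step coefficients in i) and the update in iii) implement the DDPM posterior mean over the clipped estimate of x_0:

\hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}\, x_t \;-\; \sqrt{\frac{1}{\bar\alpha_t} - 1}\;\epsilon_\theta(x_t, t),
\qquad
\mu_t = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\operatorname{clip}(\hat{x}_0)
      + \frac{(1-\bar\alpha_{t-1})\,\sqrt{\alpha_t}}{1-\bar\alpha_t}\, x_t,

where the four factors correspond to `clip_image_coeff`, `clip_noise_coeff`, `clip_coeff`, and `image_coeff` in the loops above.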