From 4d1536bb2e141722a6a6fa53294b5a61d5f0ade7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:27 +0200 Subject: [PATCH] add vae model --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/vae.py | 669 +++++++++++++++++++++++++++++++ 3 files changed, 671 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/vae.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 99e08deea7..b62139cd0b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel +from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, VQModel from .pipeline_utils import DiffusionPipeline from .pipelines import ( BDDMPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 71e321e111..3bba113339 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet from .unet_sde_score_estimation import NCSNpp +from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py new file mode 100644 index 0000000000..2824462010 --- /dev/null +++ b/src/diffusers/models/vae.py @@ -0,0 +1,669 @@ +import math + +import numpy as np +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin +from .attention import AttentionBlock +from .resnet import Downsample, Upsample + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal + embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section + 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +# class AttnBlock(nn.Module): +# def __init__(self, in_channels): +# super().__init__() +# self.in_channels = in_channels + +# self.norm = Normalize(in_channels) +# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + +# def forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) + +# # compute attention +# b, c, h, w = q.shape +# q = q.reshape(b, c, h * w) +# q = q.permute(0, 2, 1) # b,hw,c +# k = k.reshape(b, c, h * w) # b,c,hw +# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] +# w_ = w_ * (int(c) ** (-0.5)) +# w_ = torch.nn.functional.softmax(w_, dim=2) + +# # attend to values +# v = v.reshape(b, c, h * w) +# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) +# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] +# h_ = h_.reshape(b, c, h, w) + +# h_ = self.proj_out(h_) + +# return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, use_conv=resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + n_embed, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + n_embed=n_embed, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior