From 4d1536bb2e141722a6a6fa53294b5a61d5f0ade7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:27 +0200 Subject: [PATCH 01/12] add vae model --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/vae.py | 669 +++++++++++++++++++++++++++++++ 3 files changed, 671 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/vae.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 99e08deea7..b62139cd0b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel +from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, VQModel from .pipeline_utils import DiffusionPipeline from .pipelines import ( BDDMPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 71e321e111..3bba113339 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet from .unet_sde_score_estimation import NCSNpp +from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py new file mode 100644 index 0000000000..2824462010 --- /dev/null +++ b/src/diffusers/models/vae.py @@ -0,0 +1,669 @@ +import math + +import numpy as np +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin +from .attention import AttentionBlock +from .resnet import Downsample, Upsample + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal + embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section + 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +# class AttnBlock(nn.Module): +# def __init__(self, in_channels): +# super().__init__() +# self.in_channels = in_channels + +# self.norm = Normalize(in_channels) +# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) +# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + +# def forward(self, x): +# h_ = x +# h_ = self.norm(h_) +# q = self.q(h_) +# k = self.k(h_) +# v = self.v(h_) + +# # compute attention +# b, c, h, w = q.shape +# q = q.reshape(b, c, h * w) +# q = q.permute(0, 2, 1) # b,hw,c +# k = k.reshape(b, c, h * w) # b,c,hw +# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] +# w_ = w_ * (int(c) ** (-0.5)) +# w_ = torch.nn.functional.softmax(w_, dim=2) + +# # attend to values +# v = v.reshape(b, c, h * w) +# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) +# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] +# h_ = h_.reshape(b, c, h, w) + +# h_ = self.proj_out(h_) + +# return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + 
**ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = 
AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttentionBlock(block_in, overwrite_qkv=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, use_conv=resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + 
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class VQModel(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + n_embed, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + n_embed=n_embed, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(ModelMixin, ConfigMixin): + def __init__( + self, + ch, + out_ch, + num_res_blocks, + attn_resolutions, + in_channels, + resolution, + z_channels, + embed_dim, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ch_mult=(1, 2, 4, 8), + dropout=0.0, + double_z=True, + resamp_with_conv=True, + give_pre_end=False, + ): + super().__init__() + + # register all __init__ params with self.register + self.register_to_config( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + embed_dim=embed_dim, + remap=remap, + sane_index_shape=sane_index_shape, + ch_mult=ch_mult, + dropout=dropout, + double_z=double_z, + resamp_with_conv=resamp_with_conv, + 
give_pre_end=give_pre_end, + ) + + # pass init params to Encoder + self.encoder = Encoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + double_z=double_z, + give_pre_end=give_pre_end, + ) + + # pass init params to Decoder + self.decoder = Decoder( + ch=ch, + out_ch=out_ch, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolutions, + in_channels=in_channels, + resolution=resolution, + z_channels=z_channels, + ch_mult=ch_mult, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + give_pre_end=give_pre_end, + ) + + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior From 36e1893c6f3d171bfff288ef5fa3f66fba2bc177 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:38 +0200 Subject: [PATCH 02/12] remove vae from ldm pipeline --- .../pipeline_latent_diffusion.py | 849 ------------------ 1 file changed, 849 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index f7dc8ea343..6a4eb1621f 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -1,4 +1,3 @@ -import math from typing import Optional, Tuple, Union import numpy as np @@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging -from ...configuration_utils import ConfigMixin -from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline @@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel): return sequence_output -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def 
forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - ): - super().__init__() - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, x, t=None): - # assert x.shape[2] == x.shape[3] == 
self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - 
# middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. 
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits == False, "Only for interface compatible with Gumbel" - assert return_logits == False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, 
height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - 
logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class AutoencoderKL(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - class LatentDiffusionPipeline(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__() From 17e5b4921a2dae2b9f33b3cb603d328f94e2d36f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:48 +0200 Subject: [PATCH 03/12] remove vae from ldm uncond pipe --- .../pipeline_latent_diffusion_uncond.py | 561 ------------------ 1 file changed, 561 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 269dfb39cd..873683b06d 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -1,571 +1,10 @@ -import math - -import numpy as np import torch -import torch.nn as nn import tqdm -from ...configuration_utils import ConfigMixin -from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. 
Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = 
torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - 
num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. 
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to 
match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - class LatentDiffusionUncondPipeline(DiffusionPipeline): def __init__(self, vqvae, unet, noise_scheduler): super().__init__() From 2ac9b026097d44d117887c61761e61ddbc681909 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:43:04 +0200 Subject: [PATCH 04/12] remove AutoencoderKL from pipe __init__ --- src/diffusers/pipelines/latent_diffusion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/latent_diffusion/__init__.py b/src/diffusers/pipelines/latent_diffusion/__init__.py index 614c3db4bd..5ce717caab 100644 --- a/src/diffusers/pipelines/latent_diffusion/__init__.py +++ b/src/diffusers/pipelines/latent_diffusion/__init__.py @@ -2,4 +2,4 @@ from ...utils import is_transformers_available if is_transformers_available(): - from .pipeline_latent_diffusion import AutoencoderKL, LatentDiffusionPipeline, LDMBertModel + from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel From 99568c5a39cb729591963c8a4bbfb16b7f76d253 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:53:58 +0200 Subject: [PATCH 05/12] cleanup vae file --- src/diffusers/models/vae.py | 38 ------------------------------------- 1 file changed, 38 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 2824462010..282cb14a92 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py 
@@ -82,44 +82,6 @@ class ResnetBlock(nn.Module): return x + h -# class AttnBlock(nn.Module): -# def __init__(self, in_channels): -# super().__init__() -# self.in_channels = in_channels - -# self.norm = Normalize(in_channels) -# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) -# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - -# def forward(self, x): -# h_ = x -# h_ = self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) - -# # compute attention -# b, c, h, w = q.shape -# q = q.reshape(b, c, h * w) -# q = q.permute(0, 2, 1) # b,hw,c -# k = k.reshape(b, c, h * w) # b,c,hw -# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] -# w_ = w_ * (int(c) ** (-0.5)) -# w_ = torch.nn.functional.softmax(w_, dim=2) - -# # attend to values -# v = v.reshape(b, c, h * w) -# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) -# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] -# h_ = h_.reshape(b, c, h, w) - -# h_ = self.proj_out(h_) - -# return x + h_ - - class Encoder(nn.Module): def __init__( self, From 0b7daa6de99b759f4c6c0b723716b70e4107a2c8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:56:19 +0200 Subject: [PATCH 06/12] add forward for vq model --- src/diffusers/models/vae.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 282cb14a92..1a4c8904e4 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -534,6 +534,11 @@ class VQModel(ModelMixin, ConfigMixin): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec + + def forward(self, x): + h = self.encode(x) + dec = self.decode(h) + return dec class AutoencoderKL(ModelMixin, ConfigMixin): From bae04ea9d891e22de9c66247d2429afebfa45af1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 12:34:24 +0200 Subject: [PATCH 07/12] add test for VQModel --- tests/test_modeling_utils.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2e0be1d334..f38e629d9e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,6 +22,7 @@ import numpy as np import torch from diffusers import ( + AutoencoderKL, BDDMPipeline, DDIMPipeline, DDIMScheduler, @@ -44,6 +45,7 @@ from diffusers import ( UNetGradTTSModel, UNetLDMModel, UNetModel, + VQModel, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -805,6 +807,82 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class VQModelTests(ModelTesterMixin, unittest.TestCase): + model_class = VQModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "out_ch": 3, + "num_res_blocks": 1, + "attn_resolutions": [], + "in_channels": 3, + 
"resolution": 32, + "z_channels": 3, + "n_embed": 256, + "embed_dim": 3, + "sane_index_shape": False, + "ch_mult": (1,), + "dropout": 0.0, + "double_z": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, + -0.4218]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models From 976173a4bfab7ba5a091a1901e7ebb9d69d2e32d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 12:34:28 +0200 Subject: [PATCH 08/12] style --- src/diffusers/models/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 1a4c8904e4..bfc4a96e00 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -534,7 +534,7 @@ class VQModel(ModelMixin, ConfigMixin): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec - + def forward(self, x): h = self.encode(x) dec = self.decode(h) From 333a8da678882930a3c981df09128c7d8ffd0212 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 13:52:04 +0200 Subject: [PATCH 09/12] add tests for AutoencoderKL --- src/diffusers/models/vae.py | 6 +-- tests/test_modeling_utils.py | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index bfc4a96e00..6bd9a07099 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) + def forward(self, x, sample_posterior=False): + posterior = self.encode(x) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) - return dec, posterior + return dec diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f38e629d9e..c567fdbbdb 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -46,6 +46,7 @@ from diffusers import ( UNetLDMModel, UNetModel, VQModel, + AutoencoderKL, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = 
AutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "ch_mult": (1,), + "embed_dim": 4, + "in_channels": 3, + "num_res_blocks": 1, + "out_ch": 3, + "resolution": 32, + "z_channels": 4, + "attn_resolutions": [] + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image, sample_posterior=True) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, + 0.1750]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models From 65788e46edfb60a31782a2bda0ba01f594359785 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 15:12:58 +0200 Subject: [PATCH 10/12] add scaled_linear schedule in DDIM --- src/diffusers/schedulers/scheduling_ddim.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index f626cb1ca5..70fe2490e4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -71,6 +71,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): if beta_schedule == "linear": self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
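+            # (i.e. betas are linear in sqrt-space: interpolate between beta_start**0.5 and
+            #  beta_end**0.5, then square, which yields smaller early betas than "linear")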
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(timesteps) From 859ffea2b1a307ea6a526b6dc92d6843b923ed39 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 15:25:51 +0200 Subject: [PATCH 11/12] add test for ldm uncond --- tests/test_modeling_utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index c567fdbbdb..743bef0e0e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -34,6 +34,7 @@ from diffusers import ( GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, + LatentDiffusionUncondPipeline, NCSNpp, PNDMPipeline, PNDMScheduler, @@ -46,7 +47,6 @@ from diffusers import ( UNetLDMModel, UNetModel, VQModel, - AutoencoderKL, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -915,7 +915,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): "out_ch": 3, "resolution": 32, "z_channels": 4, - "attn_resolutions": [] + "attn_resolutions": [], } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -925,7 +925,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): def test_training(self): pass - + def test_from_pretrained_hub(self): model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) self.assertIsNotNone(model) @@ -1151,6 +1151,19 @@ class PipelineTesterMixin(unittest.TestCase): assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + @slow + def test_ldm_uncond(self): + ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256") + + generator = torch.manual_seed(0) + image = ldm(generator=generator, num_inference_steps=5) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) From 8cba133f36a7ed162b0c52c5078459b5ba91c1bb Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 29 Jun 2022 15:37:23 +0200 Subject: [PATCH 12/12] Add the model card template (#43) * add a metrics logger * fix LatentDiffusionUncondPipeline * add VQModel in init * add image logging to tensorboard * switch manual templates to the modelcards package * hide ldm example Co-authored-by: patil-suraj --- .../train_latent_text_to_image.py | 0 examples/train_unconditional.py | 137 +++++++++++------- setup.py | 4 + src/diffusers/dependency_versions_table.py | 1 + src/diffusers/hub_utils.py | 50 ++++--- .../latent_diffusion_uncond/__init__.py | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 9 ++ src/diffusers/schedulers/scheduling_ddpm.py | 3 +- src/diffusers/utils/model_card_template.md | 50 +++++++ 9 files changed, 180 insertions(+), 76 deletions(-) rename examples/{ => experimental}/train_latent_text_to_image.py (100%) create mode 100644 src/diffusers/utils/model_card_template.md diff --git a/examples/train_latent_text_to_image.py b/examples/experimental/train_latent_text_to_image.py similarity index 100% rename from 
examples/train_latent_text_to_image.py rename to examples/experimental/train_latent_text_to_image.py diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index 5398c78268..0813cadc63 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -4,14 +4,13 @@ import os import torch import torch.nn.functional as F -import PIL.Image from accelerate import Accelerator +from accelerate.logging import get_logger from datasets import load_dataset -from diffusers import DDPMPipeline, DDPMScheduler, UNetModel +from diffusers import DDIMPipeline, DDIMScheduler, UNetModel from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import logging from torchvision.transforms import ( CenterCrop, Compose, @@ -24,11 +23,12 @@ from torchvision.transforms import ( from tqdm.auto import tqdm -logger = logging.get_logger(__name__) +logger = get_logger(__name__) def main(args): - accelerator = Accelerator(mixed_precision=args.mixed_precision) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir) model = UNetModel( attn_resolutions=(16,), @@ -39,8 +39,14 @@ def main(args): resamp_with_conv=True, resolution=args.resolution, ) - noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) augmentations = Compose( [ @@ -58,12 +64,12 @@ def main(args): return {"input": images} dataset.set_transform(transforms) - train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) lr_scheduler = get_scheduler( - "linear", + args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.warmup_steps, + num_warmup_steps=args.lr_warmup_steps, num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, ) @@ -76,15 +82,19 @@ def main(args): if args.push_to_hub: repo = init_git_repo(args, at_init=True) + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + # Train! is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() if is_distributed else 1 - total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataloader.dataset)}") logger.info(f" Num Epochs = {args.num_epochs}") - logger.info(f" Instantaneous batch size per device = {args.batch_size}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps}") @@ -92,65 +102,71 @@ def main(args): global_step = 0 for epoch in range(args.num_epochs): model.train() - with tqdm(total=len(train_dataloader), unit="ba") as pbar: - pbar.set_description(f"Epoch {epoch}") - for step, batch in enumerate(train_dataloader): - clean_images = batch["input"] - noise_samples = torch.randn(clean_images.shape).to(clean_images.device) - bsz = clean_images.shape[0] - timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() + progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + noise_samples = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() - # add noise onto the clean images according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) + # add noise onto the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps) - if step % args.gradient_accumulation_steps != 0: - with accelerator.no_sync(model): - output = model(noisy_images, timesteps) - # predict the noise residual - loss = F.mse_loss(output, noise_samples) - loss = loss / args.gradient_accumulation_steps - accelerator.backward(loss) - else: + if step % args.gradient_accumulation_steps != 0: + with accelerator.no_sync(model): output = model(noisy_images, timesteps) # predict the noise residual loss = F.mse_loss(output, noise_samples) loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - optimizer.step() - lr_scheduler.step() - ema_model.step(model, global_step) - optimizer.zero_grad() - pbar.update(1) - pbar.set_postfix( - loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay - ) - global_step += 1 + else: + output = model(noisy_images, timesteps) + # predict the noise residual + loss = F.mse_loss(output, noise_samples) + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + ema_model.step(model, global_step) + optimizer.zero_grad() + progress_bar.update(1) + progress_bar.set_postfix( + loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay + ) + accelerator.log( + { + "train_loss": loss.detach().item(), + "epoch": epoch, + "ema_decay": ema_model.decay, + "step": global_step, + }, + step=global_step, + ) + global_step += 1 + progress_bar.close() accelerator.wait_for_everyone() # Generate a sample image for visual inspection if accelerator.is_main_process: with torch.no_grad(): - pipeline = DDPMPipeline( - unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler + pipeline = DDIMPipeline( + unet=accelerator.unwrap_model(ema_model.averaged_model), + noise_scheduler=noise_scheduler, ) generator = torch.manual_seed(0) 
# run pipeline in inference (sample random noise and denoise) - image = pipeline(generator=generator) + images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50) - # process image to PIL - image_processed = image.cpu().permute(0, 2, 3, 1) - image_processed = (image_processed + 1.0) * 127.5 - image_processed = image_processed.type(torch.uint8).numpy() - image_pil = PIL.Image.fromarray(image_processed[0]) + # denormalize the images and save to tensorboard + images_processed = (images.cpu() + 1.0) * 127.5 + images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy() - # save image - test_dir = os.path.join(args.output_dir, "test_samples") - os.makedirs(test_dir, exist_ok=True) - image_pil.save(f"{test_dir}/{epoch:04d}.png") + accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch) # save the model if args.push_to_hub: @@ -159,6 +175,8 @@ def main(args): pipeline.save_pretrained(args.output_dir) accelerator.wait_for_everyone() + accelerator.end_training() + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -167,18 +185,25 @@ if __name__ == "__main__": parser.add_argument("--output_dir", type=str, default="ddpm-model") parser.add_argument("--overwrite_output_dir", action="store_true") parser.add_argument("--resolution", type=int, default=64) - parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--warmup_steps", type=int, default=500) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--adam_beta1", type=float, default=0.95) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-6) + parser.add_argument("--adam_epsilon", type=float, default=1e-3) parser.add_argument("--ema_inv_gamma", type=float, default=1.0) parser.add_argument("--ema_power", type=float, default=3 / 4) - parser.add_argument("--ema_max_decay", type=float, default=0.999) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None) parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument("--logging_dir", type=str, default="logs") parser.add_argument( "--mixed_precision", type=str, diff --git a/setup.py b/setup.py index 97c7583b6d..7ccbd630e0 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,8 @@ _deps = [ "regex!=2019.12.17", "requests", "torch>=1.4", + "tensorboard", + "modelcards=0.1.4" ] # this is a lookup table with items like: @@ -172,6 +174,8 @@ install_requires = [ deps["requests"], deps["torch"], deps["Pillow"], + deps["tensorboard"], + deps["modelcards"], ] setup( diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 5793d2fc85..833f726179 100644 --- a/src/diffusers/dependency_versions_table.py +++ 
b/src/diffusers/dependency_versions_table.py @@ -13,4 +13,5 @@ deps = { "regex": "regex!=2019.12.17", "requests": "requests", "torch": "torch>=1.4", + "tensorboard": "tensorboard", } diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index 2ab2ff289a..8e22108318 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -19,9 +19,9 @@ import shutil from pathlib import Path from typing import Optional -import yaml from diffusers import DiffusionPipeline from huggingface_hub import HfFolder, Repository, whoami +from modelcards import CardData, ModelCard from .utils import logging @@ -29,10 +29,7 @@ from .utils import logging logger = logging.get_logger(__name__) -AUTOGENERATED_TRAINER_COMMENT = """ - -""" +MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): @@ -152,17 +149,36 @@ def create_model_card(args, model_name): if args.local_rank not in [-1, 0]: return - # TODO: replace this placeholder model card generation - model_card = "" + repo_name = get_full_repo_name(model_name, token=args.hub_token) - metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]} - metadata = yaml.dump(metadata, sort_keys=False) - if len(metadata) > 0: - model_card = f"---\n{metadata}---\n" + model_card = ModelCard.from_template( + card_data=CardData( # Card metadata object that will be converted to YAML block + language="en", + license="apache-2.0", + library_name="diffusers", + tags=[], + datasets=args.dataset, + metrics=[], + ), + template_path=MODEL_CARD_TEMPLATE_PATH, + model_name=model_name, + repo_name=repo_name, + dataset_name=args.dataset, + learning_rate=args.learning_rate, + train_batch_size=args.train_batch_size, + eval_batch_size=args.eval_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + adam_beta1=args.adam_beta1, + adam_beta2=args.adam_beta2, + adam_weight_decay=args.adam_weight_decay, + adam_epsilon=args.adam_epsilon, + lr_scheduler=args.lr_scheduler, + lr_warmup_steps=args.lr_warmup_steps, + ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, + mixed_precision=args.mixed_precision, + ) - model_card += AUTOGENERATED_TRAINER_COMMENT - - model_card += f"\n# {model_name}\n\n" - - with open(os.path.join(args.output_dir, "README.md"), "w") as f: - f.write(model_card) + card_path = os.path.join(args.output_dir, "README.md") + model_card.save(card_path) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/src/diffusers/pipelines/latent_diffusion_uncond/__init__.py index 6fe1f6d655..fa6e24945a 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/__init__.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/__init__.py @@ -1 +1 @@ -from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline +from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline, VQModel diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 70fe2490e4..97b5880c0b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -145,5 +145,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): return pred_prev_sample + def add_noise(self, original_samples, noise, timesteps): + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + 
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def __len__(self): return self.config.timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d4230ff069..f80476b9c5 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -17,7 +17,6 @@ import math import numpy as np -import torch from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -142,7 +141,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return pred_prev_sample - def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor): + def add_noise(self, original_samples, noise, timesteps): sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/utils/model_card_template.md b/src/diffusers/utils/model_card_template.md new file mode 100644 index 0000000000..f19c85b0fc --- /dev/null +++ b/src/diffusers/utils/model_card_template.md @@ -0,0 +1,50 @@ +--- +{{ card_data }} +--- + + + +# {{ model_name | default("Diffusion Model") }} + +## Model description + +This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library +on the `{{ dataset_name }}` dataset. + +## Intended uses & limitations + +#### How to use + +```python +# TODO: add an example code snippet for running this diffusion pipeline +``` + +#### Limitations and bias + +[TODO: provide examples of latent issues and potential remediations] + +## Training data + +[TODO: describe the data used to train the model] + +### Training hyperparameters + +The following hyperparameters were used during training: +- learning_rate: {{ learning_rate }} +- train_batch_size: {{ train_batch_size }} +- eval_batch_size: {{ eval_batch_size }} +- gradient_accumulation_steps: {{ gradient_accumulation_steps }} +- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} +- lr_scheduler: {{ lr_scheduler }} +- lr_warmup_steps: {{ lr_warmup_steps }} +- ema_inv_gamma: {{ ema_inv_gamma }} +- ema_inv_gamma: {{ ema_power }} +- ema_inv_gamma: {{ ema_max_decay }} +- mixed_precision: {{ mixed_precision }} + +### Training results + +📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) + +
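
As a minimal sanity check of the `VQModel` added in this series, the sketch below round-trips a random image batch through encode -> quantize -> decode, mirroring `test_output_pretrained` in `tests/test_modeling_utils.py`. It is only a sketch: the `fusing/vqgan-dummy` checkpoint name is taken from those tests and is assumed to be available on the Hub.

```python
import torch

from diffusers import VQModel

# Dummy VQ-GAN checkpoint used by the tests above (assumed to be available on the Hub).
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model.eval()

# Random batch shaped like the images the checkpoint was configured for.
image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)

with torch.no_grad():
    # forward() runs encode -> vector quantization -> decode and returns the reconstruction.
    reconstruction = model(image)

print(reconstruction.shape)  # same batch/spatial layout as the input
```

The same pattern applies to the `AutoencoderKL` tests, except that `model(image, sample_posterior=True)` draws a sample from the latent posterior instead of taking its mode.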