From 36e1893c6f3d171bfff288ef5fa3f66fba2bc177 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 11:38:38 +0200 Subject: [PATCH] remove vae from ldm pipeline --- .../pipeline_latent_diffusion.py | 849 ------------------ 1 file changed, 849 deletions(-) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index f7dc8ea343..6a4eb1621f 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -1,4 +1,3 @@ -import math from typing import Optional, Tuple, Union import numpy as np @@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging -from ...configuration_utils import ConfigMixin -from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline @@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel): return sequence_output -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal - embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section - 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - ): - super().__init__() - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, x, t=None): - # assert x.shape[2] == x.shape[3] == self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits == False, "Only for interface compatible with Gumbel" - assert return_logits == False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class AutoencoderKL(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register_to_config( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - class LatentDiffusionPipeline(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__()