From 99568c5a39cb729591963c8a4bbfb16b7f76d253 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 29 Jun 2022 11:53:58 +0200
Subject: [PATCH] cleanup vae file

---
 src/diffusers/models/vae.py | 38 --------------------------------------
 1 file changed, 38 deletions(-)

diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 2824462010..282cb14a92 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -82,44 +82,6 @@ class ResnetBlock(nn.Module):
         return x + h
 
 
-# class AttnBlock(nn.Module):
-#     def __init__(self, in_channels):
-#         super().__init__()
-#         self.in_channels = in_channels
-
-#         self.norm = Normalize(in_channels)
-#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-#     def forward(self, x):
-#         h_ = x
-#         h_ = self.norm(h_)
-#         q = self.q(h_)
-#         k = self.k(h_)
-#         v = self.v(h_)
-
-#         # compute attention
-#         b, c, h, w = q.shape
-#         q = q.reshape(b, c, h * w)
-#         q = q.permute(0, 2, 1)  # b,hw,c
-#         k = k.reshape(b, c, h * w)  # b,c,hw
-#         w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
-#         w_ = w_ * (int(c) ** (-0.5))
-#         w_ = torch.nn.functional.softmax(w_, dim=2)
-
-#         # attend to values
-#         v = v.reshape(b, c, h * w)
-#         w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
-#         h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-#         h_ = h_.reshape(b, c, h, w)
-
-#         h_ = self.proj_out(h_)
-
-#         return x + h_
-
-
 class Encoder(nn.Module):
     def __init__(
         self,
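
For reference, the commented-out AttnBlock deleted by this patch described single-head self-attention over the spatial positions of a feature map, with 1x1 convolutions acting as the q/k/v projections and a residual connection around the output projection. Below is a minimal standalone sketch of that same computation, not diffusers' actual attention module: the class name SpatialSelfAttention is hypothetical, and Normalize is assumed to be a GroupNorm as used elsewhere in vae.py.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSelfAttention(nn.Module):
    """Sketch of the attention described by the deleted comment block."""

    def __init__(self, in_channels: int, num_groups: int = 32):
        super().__init__()
        # Assumption: Normalize(in_channels) corresponds to a GroupNorm.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        # 1x1 convolutions are per-pixel linear projections for q, k, v.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)

        b, c, h, w = q.shape
        # Flatten the spatial dims so every pixel attends to every other pixel.
        q = q.reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, h * w)                   # (b, c, hw)
        attn = torch.bmm(q, k) * (c ** -0.5)         # (b, hw_q, hw_k), scaled dot product
        attn = F.softmax(attn, dim=2)                # normalize over the key positions

        v = v.reshape(b, c, h * w)                   # (b, c, hw)
        attn = attn.permute(0, 2, 1)                 # (b, hw_k, hw_q)
        h_ = torch.bmm(v, attn).reshape(b, c, h, w)  # weighted sum of values per query pixel

        return x + self.proj_out(h_)                 # residual connection


if __name__ == "__main__":
    block = SpatialSelfAttention(64)
    out = block(torch.randn(1, 64, 16, 16))
    print(out.shape)  # torch.Size([1, 64, 16, 16])

The 1/sqrt(c) scaling, the softmax over key positions, and the residual x + proj_out(h_) mirror the deleted comment block line for line; only the normalization layer is an assumption.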