diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention.py similarity index 100% rename from src/diffusers/models/attention2d.py rename to src/diffusers/models/attention.py diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index dc3201f18b..13765e1f8b 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -15,22 +15,12 @@ # helpers functions -import copy -import math -from pathlib import Path - import torch from torch import nn -from torch.cuda.amp import GradScaler, autocast -from torch.optim import Adam -from torch.utils import data - -from PIL import Image -from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: - # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block]) - # h = self.down[i_level].attn_2[i_block](h) - h = self.down[i_level].attn[i_block](h) - # print("Result", (h - h_2).abs().sum()) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 0ffd20b35f..53763ddaa0 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 4854442ce1..bc0c1e7a22 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,9 +1,8 @@ import torch -from numpy import pad from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import LinearAttention +from .attention import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample @@ -55,32 +54,6 @@ class ResnetBlock(torch.nn.Module): return output -class old_LinearAttention(torch.nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super(LinearAttention, self).__init__() - self.heads = heads - self.dim_head = dim_head - hidden_dim = dim_head * heads - self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) - q, k, v = ( - qkv.reshape(b, 3, self.heads, self.dim_head, h, w) - .permute(1, 0, 2, 3, 4, 5) - .reshape(3, b, self.heads, self.dim_head, -1) - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) - out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) - return self.to_out(out) - - class Residual(torch.nn.Module): def __init__(self, fn): super(Residual, self).__init__() diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 5315285afa..0012886a5e 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import get_timestep_embedding from .resnet import Downsample, Upsample diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 692ef99576..17218a7a7e 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -26,7 +26,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin -from .attention2d import AttentionBlock +from .attention import AttentionBlock from .embeddings import GaussianFourierProjection, get_timestep_embedding