mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
merge unet attention into glide attention
This commit is contained in:
@@ -32,62 +32,6 @@ class LinearAttention(torch.nn.Module):
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
# unet.py
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = normalization(in_channels, swish=None, eps=1e-6)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
print("x", x.abs().sum())
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
||||
print("hid_states shape", h_.shape)
|
||||
print("hid_states", h_.abs().sum())
|
||||
print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:])
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
print(self.q)
|
||||
print("q_shape", q.shape)
|
||||
print("q", q.abs().sum())
|
||||
# print("k_shape", k.shape)
|
||||
# print("k", k.abs().sum())
|
||||
# print("v_shape", v.shape)
|
||||
# print("v", v.abs().sum())
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
|
||||
print("weight", w_.abs().sum())
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
# unet_glide.py & unet_ldm.py
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
|
||||
@@ -32,7 +32,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample, Upsample
|
||||
from .attention2d import AttnBlock, AttentionBlock
|
||||
from .attention2d import AttentionBlock
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
@@ -86,44 +86,6 @@ class ResnetBlock(nn.Module):
|
||||
return x + h
|
||||
|
||||
|
||||
#class AttnBlock(nn.Module):
|
||||
# def __init__(self, in_channels):
|
||||
# super().__init__()
|
||||
# self.in_channels = in_channels
|
||||
#
|
||||
# self.norm = Normalize(in_channels)
|
||||
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
#
|
||||
# compute attention
|
||||
# b, c, h, w = q.shape
|
||||
# q = q.reshape(b, c, h * w)
|
||||
# q = q.permute(0, 2, 1) # b,hw,c
|
||||
# k = k.reshape(b, c, h * w) # b,c,hw
|
||||
# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
# w_ = w_ * (int(c) ** (-0.5))
|
||||
# w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
#
|
||||
# attend to values
|
||||
# v = v.reshape(b, c, h * w)
|
||||
# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
# h_ = h_.reshape(b, c, h, w)
|
||||
#
|
||||
# h_ = self.proj_out(h_)
|
||||
#
|
||||
# return x + h_
|
||||
|
||||
|
||||
class UNetModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -186,7 +148,6 @@ class UNetModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
# attn.append(AttnBlock(block_in))
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
@@ -202,7 +163,6 @@ class UNetModel(ModelMixin, ConfigMixin):
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
# self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
@@ -228,7 +188,6 @@ class UNetModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
# attn.append(AttnBlock(block_in))
|
||||
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
|
||||
@@ -858,25 +858,26 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
assert image.shape == (1, 3, 32, 32)
|
||||
expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
|
||||
expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ddim_cifar10(self):
|
||||
generator = torch.manual_seed(0)
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetModel.from_pretrained(model_id)
|
||||
noise_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddim(generator=generator, eta=0.0)
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
assert image.shape == (1, 3, 32, 32)
|
||||
expected_slice = torch.tensor(
|
||||
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
|
||||
[-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
|
||||
)
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@@ -895,7 +896,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
assert image.shape == (1, 3, 32, 32)
|
||||
expected_slice = torch.tensor(
|
||||
[-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
|
||||
[-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494]
|
||||
)
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@@ -966,24 +967,22 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
|
||||
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=2)
|
||||
|
||||
expected_image_sum = 3382810112.0
|
||||
expected_image_mean = 1075.366455078125
|
||||
expected_image_sum = 3382849024.0
|
||||
expected_image_mean = 1075.3788
|
||||
|
||||
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
|
||||
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
|
||||
|
||||
@slow
|
||||
def test_score_sde_vp_pipeline(self):
|
||||
|
||||
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
|
||||
scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user