mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def conv_transpose_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
@@ -64,7 +64,7 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
|
||||
if use_conv_transpose:
|
||||
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
|
||||
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -125,87 +125,7 @@ class Downsample(nn.Module):
|
||||
return self.down(x)
|
||||
|
||||
|
||||
class UNetUpsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class GlideUpsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class LDMUpsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class GradTTSUpsample(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(Upsample, self).__init__()
|
||||
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
# TODO (patil-suraj): needs test
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
@@ -31,6 +31,7 @@ from tqdm import tqdm
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
@@ -42,20 +43,6 @@ def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
||||
@@ -10,6 +10,7 @@ except:
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
|
||||
|
||||
class Mish(torch.nn.Module):
|
||||
@@ -17,15 +18,6 @@ class Mish(torch.nn.Module):
|
||||
return x * torch.tanh(torch.nn.functional.softplus(x))
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(Upsample, self).__init__()
|
||||
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(Downsample, self).__init__()
|
||||
@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
|
||||
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
|
||||
Residual(Rezero(LinearAttention(dim_in))),
|
||||
Upsample(dim_in),
|
||||
Upsample(dim_in, use_conv_transpose=True),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ except:
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Upsample
|
||||
|
||||
|
||||
def exists(val):
|
||||
@@ -81,60 +82,62 @@ def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
# class LinearAttention(nn.Module):
|
||||
# def __init__(self, dim, heads=4, dim_head=32):
|
||||
# super().__init__()
|
||||
# self.heads = heads
|
||||
# hidden_dim = dim_head * heads
|
||||
# self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
# self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# b, c, h, w = x.shape
|
||||
# qkv = self.to_qkv(x)
|
||||
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# k = k.softmax(dim=-1)
|
||||
# context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
# out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
# return self.to_out(out)
|
||||
#
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c")
|
||||
k = rearrange(k, "b c h w -> b c (h w)")
|
||||
w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, "b c h w -> b c (h w)")
|
||||
w_ = rearrange(w_, "b i j -> b j i")
|
||||
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
# class SpatialSelfAttention(nn.Module):
|
||||
# def __init__(self, in_channels):
|
||||
# super().__init__()
|
||||
# self.in_channels = in_channels
|
||||
#
|
||||
# self.norm = Normalize(in_channels)
|
||||
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
#
|
||||
# compute attention
|
||||
# b, c, h, w = q.shape
|
||||
# q = rearrange(q, "b c h w -> b (h w) c")
|
||||
# k = rearrange(k, "b c h w -> b c (h w)")
|
||||
# w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
#
|
||||
# w_ = w_ * (int(c) ** (-0.5))
|
||||
# w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
#
|
||||
# attend to values
|
||||
# v = rearrange(v, "b c h w -> b c (h w)")
|
||||
# w_ = rearrange(w_, "b i j -> b j i")
|
||||
# h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
# h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
# h_ = self.proj_out(h_)
|
||||
#
|
||||
# return x + h_
|
||||
#
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
@@ -377,35 +380,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
@@ -480,8 +454,8 @@ class ResBlock(TimestepBlock):
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
@@ -948,7 +922,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
||||
@@ -22,6 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.resnet import Upsample
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
||||
|
||||
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
|
||||
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
|
||||
1e-3,
|
||||
)
|
||||
|
||||
|
||||
class UpsampleBlockTests(unittest.TestCase):
|
||||
def test_upsample_default(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=False)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_conv(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=True)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_conv_out_dim(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=True, out_channels=64)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 64, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_transpose(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
from diffusers import ( # GradTTSPipeline,
|
||||
BDDMPipeline,
|
||||
DDIMPipeline,
|
||||
DDIMScheduler,
|
||||
@@ -30,7 +30,6 @@ from diffusers import (
|
||||
GlidePipeline,
|
||||
GlideSuperResUNetModel,
|
||||
GlideTextToImageUNetModel,
|
||||
GradTTSPipeline,
|
||||
GradTTSScheduler,
|
||||
LatentDiffusionPipeline,
|
||||
NCSNpp,
|
||||
@@ -511,6 +510,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
def test_output_pretrained_spatial_transformer(self):
|
||||
model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
context = torch.ones((1, 16, 64), dtype=torch.float32)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step, context=context)
|
||||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetGradTTSModel
|
||||
|
||||
Reference in New Issue
Block a user