
add fir=False back

patil-suraj
2022-07-01 12:01:59 +02:00
parent abedfb08f1
commit 5018abff6e
2 changed files with 70 additions and 32 deletions

View File

@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from functools import partial

 import numpy as np
 import torch
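
The new functools.partial import is what the fir handling further down relies on: instead of branching on the FIR flag in forward(), the block binds its resampling callable once at construction time. A minimal, standalone sketch of that binding pattern (not library code):

from functools import partial

import torch
import torch.nn.functional as F

# Pre-bind the keyword arguments once; every later call site stays identical.
upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)
assert upsample(x).shape == (1, 3, 16, 16)
assert downsample(x).shape == (1, 3, 4, 4)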
@@ -78,18 +79,24 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -102,7 +109,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x
@@ -134,6 +144,8 @@ class Downsample(nn.Module):
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
@@ -145,6 +157,8 @@ class Downsample(nn.Module):
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
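
The only effect of the new name argument in Upsample/Downsample is which attribute the convolution is registered under, and therefore which keys show up in the state_dict; registering it as Conv2d_0 presumably keeps the non-FIR path loadable against checkpoints that use the Conv2d_0 naming of the FIR modules. A standalone sketch of the pattern (hypothetical class, not the library one):

import torch.nn as nn
import torch.nn.functional as F


class NameRoutedUpsample(nn.Module):
    # Hypothetical stand-in: the same conv, registered under a configurable attribute name.
    def __init__(self, channels, name="conv"):
        super().__init__()
        self.name = name
        conv = nn.Conv2d(channels, channels, 3, padding=1)
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x) if self.name == "conv" else self.Conv2d_0(x)


# The attribute name decides the checkpoint keys:
print(list(NameRoutedUpsample(4).state_dict()))                   # ['conv.weight', 'conv.bias']
print(list(NameRoutedUpsample(4, name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']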
@@ -390,6 +404,7 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -400,8 +415,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel

+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -424,11 +451,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
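
With the callables bound in __init__, forward() no longer cares whether FIR filtering is enabled; h and x simply go through whatever self.upsample/self.downsample was selected. A rough standalone sketch of the same structure, reproducing only the fir=False branch so it runs without the library helpers (upsample_2d/downsample_2d would back the fir=True branch):

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyBigGANBlock(nn.Module):
    # Hypothetical stand-in for the resampling logic above, not the real ResnetBlockBigGANpp.
    def __init__(self, channels, up=False, down=False, fir=False):
        super().__init__()
        if fir:
            raise NotImplementedError("the FIR branch needs upsample_2d/downsample_2d from the library")
        self.up, self.down = up, down
        if up:
            self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif down:
            self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        h = x
        if self.up:
            h = self.upsample(h)
            x = self.upsample(x)
        elif self.down:
            h = self.downsample(h)
            x = self.downsample(x)
        return self.conv(h) + x


block = TinyBigGANBlock(8, up=True)
assert block(torch.randn(2, 8, 16, 16)).shape == (2, 8, 32, 32)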

View File

@@ -17,6 +17,7 @@
 import functools
 import math
+from unicodedata import name

 import numpy as np
 import torch
@@ -27,7 +28,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d


 def _setup_kernel(k):
@@ -184,17 +185,17 @@ class Combine(nn.Module):
 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -203,17 +204,17 @@ class FirUpsample(nn.Module):
 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
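
For context, fir_kernel=(1, 3, 3, 1) is a small separable low-pass filter: FIR upsampling zero-stuffs the signal and then convolves it with the normalized kernel, which is smoother than the nearest-neighbor copying used on the fir=False path. A rough 1D illustration with NumPy (conceptual only, not the upfirdn2d implementation; the gain handling here is simplified):

import numpy as np

k = np.array([1.0, 3.0, 3.0, 1.0])
k = k / k.sum() * 2                  # normalize, then compensate for the factor-2 zero-stuffing

x = np.array([1.0, 2.0, 3.0, 4.0])
stuffed = np.zeros(x.size * 2)
stuffed[::2] = x                     # [1, 0, 2, 0, 3, 0, 4, 0]

smooth = np.convolve(stuffed, k, mode="same")    # FIR-filtered upsample
nearest = np.repeat(x, 2)                        # roughly what the fir=False path does instead
print(smooth)
print(nearest)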
@@ -234,7 +235,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -258,6 +259,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
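
Passing fir through to the config registration matters because the config is what the model gets rebuilt from on load; an argument that is constructed with but never registered would silently revert to its default. A minimal, library-independent sketch of that failure mode (hypothetical classes, not the diffusers ConfigMixin):

class TinyConfigMixin:
    # Hypothetical stand-in: whatever is registered is all that survives a save/load round trip.
    def register_to_config(self, **kwargs):
        self._config = dict(kwargs)

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class TinyModel(TinyConfigMixin):
    def __init__(self, nf=128, fir=True):
        self.nf = nf
        self.fir = fir
        # Omitting fir here would make from_config() fall back to fir=True.
        self.register_to_config(nf=nf, fir=fir)


m = TinyModel(nf=64, fir=False)
assert TinyModel.from_config(m._config).fir is False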
@@ -307,24 +309,33 @@ class NCSNpp(ModelMixin, ConfigMixin):
             modules.append(Linear(nf * 4, nf * 4))

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel)
+        else:
+            print("fir false")
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
+            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
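
The kwarg renames in FirUpsample/FirDownsample are what make this switch work: the FIR and non-FIR samplers now share the channels/out_channels/use_conv vocabulary, so a single functools.partial factory can be configured here and called identically at every later site. A small sketch of that pattern with stand-in classes (hypothetical names, not the library ones):

import functools


class FirLikeUpsample:
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        self.kind = "fir"


class PlainUpsample:
    def __init__(self, channels=None, out_channels=None, use_conv=False, name="conv"):
        self.kind = "plain"


def make_up_sample(fir, fir_kernel=(1, 3, 3, 1)):
    # Mirrors the Up_sample selection above: bind the kwargs that differ, leave the shared ones open.
    if fir:
        return functools.partial(FirLikeUpsample, fir_kernel=fir_kernel)
    return functools.partial(PlainUpsample, name="Conv2d_0")


for fir in (True, False):
    Up_sample = make_up_sample(fir)
    layer = Up_sample(channels=16, out_channels=16, use_conv=True)  # same call either way
    print(fir, layer.kind)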
@@ -361,7 +372,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         in_ch *= 2

                 elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                     input_pyramid_ch = in_ch

                 hs_c.append(in_ch)
@@ -402,7 +413,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         )
                     pyramid_ch = channels
                 elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name")