mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add fir=False back
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -78,18 +79,24 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
name = self.name
|
||||
|
||||
if use_conv_transpose:
|
||||
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
@@ -102,7 +109,10 @@ class Upsample(nn.Module):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
if self.name == "conv":
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -134,6 +144,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.Conv2d_0 = conv
|
||||
else:
|
||||
self.op = conv
|
||||
|
||||
@@ -145,6 +157,8 @@ class Downsample(nn.Module):
|
||||
|
||||
if self.name == "conv":
|
||||
return self.conv(x)
|
||||
elif self.name == "Conv2d_0":
|
||||
return self.Conv2d_0(x)
|
||||
else:
|
||||
return self.op(x)
|
||||
|
||||
@@ -390,6 +404,7 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
up=False,
|
||||
down=False,
|
||||
dropout=0.1,
|
||||
fir=False,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
skip_rescale=True,
|
||||
init_scale=0.0,
|
||||
@@ -400,8 +415,20 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
if self.up:
|
||||
if self.fir:
|
||||
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
|
||||
else:
|
||||
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
||||
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
@@ -424,11 +451,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
|
||||
if self.up:
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
h = self.upsample(h)
|
||||
x = self.upsample(x)
|
||||
elif self.down:
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
h = self.downsample(h)
|
||||
x = self.downsample(x)
|
||||
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import functools
|
||||
import math
|
||||
from unicodedata import name
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -27,7 +28,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
|
||||
from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
@@ -184,17 +185,17 @@ class Combine(nn.Module):
|
||||
|
||||
|
||||
class FirUpsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
self.with_conv = with_conv
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_ch = out_ch
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if self.use_conv:
|
||||
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
else:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
@@ -203,17 +204,17 @@ class FirUpsample(nn.Module):
|
||||
|
||||
|
||||
class FirDownsample(nn.Module):
|
||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if with_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.with_conv = with_conv
|
||||
self.out_ch = out_ch
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
if self.use_conv:
|
||||
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
|
||||
else:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
@@ -234,7 +235,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
conv_size=3,
|
||||
dropout=0.0,
|
||||
embedding_type="fourier",
|
||||
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs
|
||||
fir=True,
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
fourier_scale=16,
|
||||
init_scale=0.0,
|
||||
@@ -258,6 +259,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
conv_size=conv_size,
|
||||
dropout=dropout,
|
||||
embedding_type=embedding_type,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
fourier_scale=fourier_scale,
|
||||
init_scale=init_scale,
|
||||
@@ -307,24 +309,33 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
modules.append(Linear(nf * 4, nf * 4))
|
||||
|
||||
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
|
||||
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
|
||||
if self.fir:
|
||||
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel)
|
||||
else:
|
||||
Up_sample = functools.partial(Upsample, name="Conv2d_0")
|
||||
|
||||
if progressive == "output_skip":
|
||||
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
|
||||
elif progressive == "residual":
|
||||
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
|
||||
|
||||
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
|
||||
if self.fir:
|
||||
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel)
|
||||
else:
|
||||
print("fir false")
|
||||
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
|
||||
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
|
||||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
|
||||
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
|
||||
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
@@ -361,7 +372,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
in_ch *= 2
|
||||
|
||||
elif progressive_input == "residual":
|
||||
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
|
||||
input_pyramid_ch = in_ch
|
||||
|
||||
hs_c.append(in_ch)
|
||||
@@ -402,7 +413,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
)
|
||||
pyramid_ch = channels
|
||||
elif progressive == "residual":
|
||||
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
||||
modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
|
||||
pyramid_ch = in_ch
|
||||
else:
|
||||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
Reference in New Issue
Block a user