diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index f48a94039e..bad14f7e2a 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from functools import partial
 
 import numpy as np
 import torch
@@ -78,18 +79,24 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name
 
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv
 
     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -102,7 +109,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")
 
         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)
 
         return x
 
@@ -134,6 +144,8 @@ class Downsample(nn.Module):
 
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
 
@@ -145,6 +157,8 @@ class Downsample(nn.Module):
 
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
 
@@ -390,6 +404,7 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -400,8 +415,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel
 
+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -424,11 +451,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))
 
         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)
 
         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
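The resnet.py hunks above replace per-call FIR branching with resampling callables that are bound once in `__init__` via `functools.partial`, so `forward()` can call `self.upsample(h)` / `self.downsample(h)` uniformly. A minimal runnable sketch of that dispatch pattern (`naive_fir_upsample` below is a hypothetical stand-in for diffusers' FIR-filtered `upsample_2d`, not the real implementation):

from functools import partial

import torch
import torch.nn.functional as F


def naive_fir_upsample(x, k=(1, 3, 3, 1), factor=2):
    # Hypothetical stand-in: ignores the FIR taps `k` and simply
    # nearest-neighbor upsamples by `factor`.
    return F.interpolate(x, scale_factor=float(factor), mode="nearest")


for fir in (True, False):
    # Same two branches as ResnetBlockBigGANpp.__init__ above.
    upsample = (
        partial(naive_fir_upsample, k=(1, 3, 3, 1), factor=2)
        if fir
        else partial(F.interpolate, scale_factor=2.0, mode="nearest")
    )
    x = torch.randn(1, 4, 8, 8)
    assert upsample(x).shape == (1, 4, 16, 16)

Either way the bound callable has the same signature at the call site, which is what lets `forward()` stay branch-free.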
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index 9c82e53e70..d9a4732f0b 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
 
 
 def _setup_kernel(k):
@@ -184,17 +184,17 @@ class Combine(nn.Module):
 
 
 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels
 
     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -203,17 +203,17 @@ class FirUpsample(nn.Module):
 
 
 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels
 
     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
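The renames above (with_conv -> use_conv, in_ch/out_ch -> channels/out_channels) align FirUpsample/FirDownsample with the constructor keywords of the generic Upsample/Downsample, so the functools.partial factories in NCSNpp below can build either class with identical kwargs. The name="Conv2d_0" argument matters because the attribute name chosen in __init__ becomes the state_dict() key prefix, which presumably keeps the generic modules loadable from checkpoints whose weights were saved under Conv2d_0.* keys, as FirUpsample/FirDownsample save theirs. A minimal sketch, with a simplified Down class standing in for diffusers' Downsample:

import torch.nn as nn


class Down(nn.Module):
    def __init__(self, channels, name="conv"):
        super().__init__()
        conv = nn.Conv2d(channels, channels, 3, stride=2)
        self.name = name
        # Registering the same conv under a different attribute name changes
        # only the checkpoint keys, not the computation.
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, x):
        return self.conv(x) if self.name == "conv" else self.Conv2d_0(x)


print(list(Down(4).state_dict()))                   # ['conv.weight', 'conv.bias']
print(list(Down(4, name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']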
@@ -234,7 +234,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -258,6 +258,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -307,24 +308,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
         modules.append(Linear(nf * 4, nf * 4))
 
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")
 
         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)
 
-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
 
         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)
 
         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
+            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
@@ -361,7 +370,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         in_ch *= 2
 
                 elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                     input_pyramid_ch = in_ch
 
         hs_c.append(in_ch)
@@ -402,7 +411,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     )
                     pyramid_ch = channels
                 elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name")
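The fir/fir_kernel pair threaded through the hunks above feeds _setup_kernel (unchanged context near the top of this file), which expands the 1-D tap vector into the normalized 2-D filter that upfirdn2d applies. A short sketch of that expansion, assuming the standard score_sde recipe (outer product, then normalize to unit sum):

import numpy as np

k = np.asarray((1, 3, 3, 1), dtype=np.float32)
k = np.outer(k, k)  # separable 4x4 low-pass filter
k /= np.sum(k)      # normalize the filter to unit sum
print(k.shape, k.sum())  # (4, 4) 1.0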