From a554f5a7ee64e7b1d3e62e508b5701ebfb08331b Mon Sep 17 00:00:00 2001
From: Chong
Date: Tue, 22 Aug 2023 11:32:06 +0800
Subject: [PATCH] update

---
 src/diffusers/models/adapter.py | 64 +++++++++------------------------
 1 file changed, 16 insertions(+), 48 deletions(-)

diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index e134807ec4..04cf9d9ba1 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -173,31 +173,28 @@ class FullAdapter_XL(nn.Module):
         in_channels: int = 3,
         channels: List[int] = [320, 640, 1280, 1280],
         num_res_blocks: int = 2,
-        downscale_factor: int = 8,
+        downscale_factor: int = 16,
     ):
         super().__init__()
 
         in_channels = in_channels * downscale_factor**2
-        self.channels = channels
-        self.num_res_blocks = num_res_blocks
 
         self.unshuffle = nn.PixelUnshuffle(downscale_factor)
         self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
 
-        self.body = []
-        for i in range(len(channels)):
-            for j in range(num_res_blocks):
-                if (i == 2) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
-                elif (i == 1) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
-                else:
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i], channels[i], down=False))
-        self.body = nn.ModuleList(self.body)
-        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+        self.body = []
+        # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
+        for i in range(len(channels)):
+            if i == 1:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
+            elif i == 2:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
+            else:
+                self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
+
+        self.body = nn.ModuleList(self.body)
+        # XL has only one downsampling AdapterBlock (i == 2), so the overall factor is downscale_factor * 2
+        self.total_downscale_factor = downscale_factor * 2
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.unshuffle(x)
@@ -205,42 +202,13 @@ class FullAdapter_XL(nn.Module):
 
         features = []
 
-        for i in range(len(self.channels)):
-            for j in range(self.num_res_blocks):
-                idx = i * self.num_res_blocks + j
-                x = self.body[idx](x)
+        for block in self.body:
+            x = block(x)
             features.append(x)
 
         return features
 
 
-class AdapterResnetBlock_XL(nn.Module):
-    def __init__(self, channels_in, channels_out, down=False):
-        super().__init__()
-        if channels_in != channels_out:
-            self.in_conv = nn.Conv2d(channels_in, channels_out, 1)
-        else:
-            self.in_conv = None
-        self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
-        self.act = nn.ReLU()
-        self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)
-        self.downsample = None
-        if down:
-            self.downsample = Downsample2D(channels_in)
-
-    def forward(self, x):
-        if self.downsample is not None:
-            x = self.downsample(x)
-        if self.in_conv is not None:
-            x = self.in_conv(x)
-        h = x
-        h = self.block1(h)
-        h = self.act(h)
-        h = self.block2(h)
-
-        return h + x
-
-
 class AdapterBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
         super().__init__()
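
A quick shape check for the rewired FullAdapter_XL (a minimal sketch, assuming the classes are importable from diffusers.models.adapter as patched above; the expected shapes come from the comment in __init__ and a hypothetical 1024x1024 input):

    import torch
    from diffusers.models.adapter import FullAdapter_XL

    adapter = FullAdapter_XL()  # downscale_factor=16, channels=[320, 640, 1280, 1280]

    # SDXL-sized conditioning image: pixel-unshuffle by 16 gives 64x64 feature
    # maps, and the single down block (i == 2) halves them once more to 32x32.
    x = torch.randn(1, 3, 1024, 1024)
    features = adapter(x)

    print([tuple(f.shape) for f in features])
    # [(1, 320, 64, 64), (1, 640, 64, 64), (1, 1280, 32, 32), (1, 1280, 32, 32)]

    assert adapter.total_downscale_factor == 32  # 16 (unshuffle) * 2 (one down block)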