Mirror of https://github.com/huggingface/diffusers.git

commit a554f5a7ee (parent af66e4819b)
Author: Chong
Date: 2023-08-22 11:32:06 +08:00


@@ -173,31 +173,28 @@ class FullAdapter_XL(nn.Module):
         in_channels: int = 3,
         channels: List[int] = [320, 640, 1280, 1280],
         num_res_blocks: int = 2,
-        downscale_factor: int = 8,
+        downscale_factor: int = 16,
     ):
         super().__init__()

         in_channels = in_channels * downscale_factor**2
         self.channels = channels
         self.num_res_blocks = num_res_blocks
         self.unshuffle = nn.PixelUnshuffle(downscale_factor)
         self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)

-        self.body = []
-        for i in range(len(channels)):
-            for j in range(num_res_blocks):
-                if (i == 2) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
-                elif (i == 1) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
-                else:
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i], channels[i], down=False))
-        self.body = nn.ModuleList(self.body)
-        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+        self.body = []
+        # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
+        for i in range(len(channels)):
+            if i == 1:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
+            elif i == 2:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
+            else:
+                self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
+        self.body = nn.ModuleList(self.body)
+        # XL has one fewer downsampling
+        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)

     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.unshuffle(x)
@@ -205,42 +202,13 @@ class FullAdapter_XL(nn.Module):
         features = []

-        for i in range(len(self.channels)):
-            for j in range(self.num_res_blocks):
-                idx = i * self.num_res_blocks + j
-                x = self.body[idx](x)
+        for block in self.body:
+            x = block(x)
             features.append(x)

         return features


-class AdapterResnetBlock_XL(nn.Module):
-    def __init__(self, channels_in, channels_out, down=False):
-        super().__init__()
-
-        if channels_in != channels_out:
-            self.in_conv = nn.Conv2d(channels_in, channels_out, 1)
-        else:
-            self.in_conv = None
-
-        self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
-        self.act = nn.ReLU()
-        self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)
-
-        self.downsample = None
-        if down:
-            self.downsample = Downsample2D(channels_in)
-
-    def forward(self, x):
-        if self.downsample is not None:
-            x = self.downsample(x)
-
-        if self.in_conv is not None:
-            x = self.in_conv(x)
-
-        h = x
-        h = self.block1(h)
-        h = self.act(h)
-        h = self.block2(h)
-
-        return h + x
-
-
 class AdapterBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
         super().__init__()
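
The feature-dimension comment in the new __init__ can be sanity-checked without instantiating the model. The sketch below is not part of the commit: it traces the spatial size of a control image through nn.PixelUnshuffle and the single down=True AdapterBlock; the 1024x1024 input resolution is an assumption chosen to match SDXL's default, while channels and downscale_factor are taken from the diff above.

# Minimal sketch (plain Python): trace spatial sizes through the new FullAdapter_XL layout.
channels = [320, 640, 1280, 1280]   # from the diff
downscale_factor = 16               # new default from the diff
height = width = 1024               # assumed SDXL control-image size

# nn.PixelUnshuffle(downscale_factor) trades resolution for channels.
height //= downscale_factor
width //= downscale_factor

for i, ch in enumerate(channels):
    # Only the i == 2 AdapterBlock is built with down=True in the new layout.
    if i == 2:
        height //= 2
        width //= 2
    print(f"stage {i}: [{ch}, {height}, {width}]")

# Expected output, matching the comment in the diff:
# stage 0: [320, 64, 64]
# stage 1: [640, 64, 64]
# stage 2: [1280, 32, 32]
# stage 3: [1280, 32, 32]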