mirror of https://github.com/huggingface/diffusers.git
update
@@ -173,31 +173,28 @@ class FullAdapter_XL(nn.Module):
         in_channels: int = 3,
         channels: List[int] = [320, 640, 1280, 1280],
         num_res_blocks: int = 2,
-        downscale_factor: int = 8,
+        downscale_factor: int = 16,
     ):
         super().__init__()
 
         in_channels = in_channels * downscale_factor**2
-        self.channels = channels
-        self.num_res_blocks = num_res_blocks
 
         self.unshuffle = nn.PixelUnshuffle(downscale_factor)
         self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
-        self.body = []
-        for i in range(len(channels)):
-            for j in range(num_res_blocks):
-                if (i == 2) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
-                elif (i == 1) and (j == 0):
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
-                else:
-                    self.body.append(
-                        AdapterResnetBlock_XL(channels[i], channels[i], down=False))
-        self.body = nn.ModuleList(self.body)
-
-        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+
+        self.body = []
+        # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
+        for i in range(len(channels)):
+            if i == 1:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
+            elif i == 2:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
+            else:
+                self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
+
+        self.body = nn.ModuleList(self.body)
+        # XL has one fewer downsampling
+        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.unshuffle(x)
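As a sanity check on the shape bookkeeping in this hunk: with downscale_factor = 16, PixelUnshuffle folds each 16x16 patch of pixels into the channel dimension, which is why in_channels is multiplied by downscale_factor**2 before conv_in. A minimal standalone sketch follows; the 1024x1024 input resolution is an assumption for illustration, not something the diff states.

import torch
import torch.nn as nn

# downscale_factor = 16 folds each 16x16 pixel patch into channels.
unshuffle = nn.PixelUnshuffle(16)
x = torch.randn(1, 3, 1024, 1024)          # hypothetical control image
x = unshuffle(x)
assert x.shape == (1, 3 * 16**2, 64, 64)   # 768 channels, 64x64 spatial

# conv_in then maps the unshuffled stack to the first stage width.
conv_in = nn.Conv2d(3 * 16**2, 320, kernel_size=3, padding=1)
x = conv_in(x)
assert x.shape == (1, 320, 64, 64)

Only the i == 2 stage is built with down=True, so the four stage outputs land at 320x64x64, 640x64x64, 1280x32x32, 1280x32x32, matching the comment in the added code. Note, though, that this puts the deepest feature map at an effective downscale of 16 * 2 = 32 from the input, while the new total_downscale_factor expression evaluates to 16 * 2**2 = 64 for len(channels) == 4; whether the len(channels) - 2 exponent is intended is not clear from this commit alone.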
@@ -205,42 +202,13 @@ class FullAdapter_XL(nn.Module):
 
         features = []
 
-        for i in range(len(self.channels)):
-            for j in range(self.num_res_blocks):
-                idx = i * self.num_res_blocks + j
-                x = self.body[idx](x)
+        for block in self.body:
+            x = block(x)
             features.append(x)
 
         return features
 
 
-class AdapterResnetBlock_XL(nn.Module):
-    def __init__(self, channels_in, channels_out, down=False):
-        super().__init__()
-        if channels_in != channels_out:
-            self.in_conv = nn.Conv2d(channels_in, channels_out, 1)
-        else:
-            self.in_conv = None
-        self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
-        self.act = nn.ReLU()
-        self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)
-        self.downsample = None
-        if down:
-            self.downsample = Downsample2D(channels_in)
-
-    def forward(self, x):
-        if self.downsample is not None:
-            x = self.downsample(x)
-        if self.in_conv is not None:
-            x = self.in_conv(x)
-        h = x
-        h = self.block1(h)
-        h = self.act(h)
-        h = self.block2(h)
-
-        return h + x
-
-
 class AdapterBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
         super().__init__()
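The second hunk collapses the hand-indexed double loop in forward into a plain iteration over self.body, with one AdapterBlock per channel stage, and deletes AdapterResnetBlock_XL in favor of the shared AdapterBlock. The diff's context window cuts AdapterBlock off right after super().__init__(); the body below is a hypothetical reconstruction, inferred only from the call sites above (in_channels, out_channels, num_res_blocks, down=True at the i == 2 stage only) and from the removed AdapterResnetBlock_XL. It is not code taken from the commit.

import torch
import torch.nn as nn


class AdapterResnetBlock(nn.Module):
    # Same-channel residual block: what remains of AdapterResnetBlock_XL once
    # the channel projection and downsampling move up into AdapterBlock.
    def __init__(self, channels):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        h = self.block2(self.act(self.block1(x)))
        return h + x


class AdapterBlock(nn.Module):
    # Hypothetical stage: optional 2x downsample, optional 1x1 projection when
    # the stage widens, then num_res_blocks residual blocks at the new width.
    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
        super().__init__()
        # nn.AvgPool2d stands in for diffusers' Downsample2D helper here.
        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2) if down else None
        self.in_conv = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )
        self.resnets = nn.Sequential(
            *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)]
        )

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)
        if self.in_conv is not None:
            x = self.in_conv(x)
        return self.resnets(x)


# Shape check against the feature-dimension comment in the first hunk:
block = AdapterBlock(640, 1280, num_res_blocks=2, down=True)
y = block(torch.randn(1, 640, 64, 64))
assert y.shape == (1, 1280, 32, 32)

With a structure like this, iterating self.body applies downsample, projection, and the residual stack once per stage, which is what the old flat list was emulating with the idx = i * self.num_res_blocks + j arithmetic.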