1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

T2I-Adapter-XL

This commit is contained in:
Chong
2023-08-21 22:16:21 +08:00
parent 7a24977ce3
commit af66e4819b
2 changed files with 79 additions and 0 deletions

View File

@@ -109,6 +109,8 @@ class T2IAdapter(ModelMixin, ConfigMixin):
if adapter_type == "full_adapter":
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
elif adapter_type == "full_adapter_xl":
self.adapter = FullAdapter_XL(in_channels, channels, num_res_blocks, downscale_factor)
elif adapter_type == "light_adapter":
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
else:
@@ -165,6 +167,80 @@ class FullAdapter(nn.Module):
return features
class FullAdapter_XL(nn.Module):
    """SDXL variant of the T2I full adapter.

    Pixel-unshuffles the conditioning image, projects it with a 3x3 conv,
    then runs `num_res_blocks` `AdapterResnetBlock_XL`s per entry in
    `channels`, collecting one feature map per stage to be added to the
    SDXL UNet's down-block residuals.

    Args:
        in_channels: channels of the conditioning image before unshuffling.
        channels: per-stage output channels; a stage-to-stage change inserts
            a 1x1 projection inside the first block of that stage.
        num_res_blocks: residual blocks per stage.
        downscale_factor: PixelUnshuffle factor applied to the input.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
    ):
        super().__init__()

        # PixelUnshuffle folds space into channels: C -> C * factor**2.
        in_channels = in_channels * downscale_factor**2

        self.channels = channels
        self.num_res_blocks = num_res_blocks
        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)

        body = []
        for i in range(len(channels)):
            for j in range(num_res_blocks):
                if (i == 2) and (j == 0):
                    # The single spatial downsample of the XL adapter, fused
                    # with the channel change of stage 2.
                    body.append(AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
                elif (i == 1) and (j == 0):
                    # Channel change without downsampling.
                    body.append(AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
                else:
                    body.append(AdapterResnetBlock_XL(channels[i], channels[i], down=False))
        self.body = nn.ModuleList(body)

        # BUG FIX: exactly one block (i == 2, j == 0) is created with
        # down=True, so the deepest feature map is only downscale_factor * 2
        # smaller than the input — not downscale_factor * 2**(len(channels)-1)
        # as previously computed (64 instead of the actual 16 for defaults).
        self.total_downscale_factor = downscale_factor * 2

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return one feature map per stage (output of each stage's last block)."""
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []
        for i in range(len(self.channels)):
            for j in range(self.num_res_blocks):
                idx = i * self.num_res_blocks + j
                x = self.body[idx](x)
            # One residual per stage, mirroring FullAdapter's contract.
            features.append(x)

        return features
class AdapterResnetBlock_XL(nn.Module):
    """Residual conv block for the XL adapter.

    Optionally downsamples the input first, projects it to `channels_out`
    with a 1x1 conv when the channel count changes, then applies a
    conv3x3 -> ReLU -> conv1x1 branch and adds it back to the input.
    """

    def __init__(self, channels_in, channels_out, down=False):
        super().__init__()

        # 1x1 projection so the residual addition is shape-compatible;
        # skipped (None) when channel counts already agree.
        self.in_conv = nn.Conv2d(channels_in, channels_out, 1) if channels_in != channels_out else None

        self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)

        # Spatial downsample applied before the projection in forward().
        self.downsample = Downsample2D(channels_in) if down else None

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        residual = x
        out = self.block2(self.act(self.block1(x)))
        return out + residual
class AdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
super().__init__()

View File

@@ -965,6 +965,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# To support T2I-Adapter-XL
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual