From af66e4819b043633785be7225eb7dcce917c5ace Mon Sep 17 00:00:00 2001
From: Chong
Date: Mon, 21 Aug 2023 22:16:21 +0800
Subject: [PATCH] T2I-Adapter-XL

---
 src/diffusers/models/adapter.py           | 76 +++++++++++++++++++++++
 src/diffusers/models/unet_2d_condition.py |  3 +
 2 files changed, 79 insertions(+)

diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index a65a3873b1..e134807ec4 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -109,6 +109,8 @@ class T2IAdapter(ModelMixin, ConfigMixin):
 
         if adapter_type == "full_adapter":
             self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+        elif adapter_type == "full_adapter_xl":
+            self.adapter = FullAdapter_XL(in_channels, channels, num_res_blocks, downscale_factor)
         elif adapter_type == "light_adapter":
             self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
         else:
@@ -165,6 +167,80 @@ class FullAdapter(nn.Module):
         return features
 
 
+class FullAdapter_XL(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280, 1280],
+        num_res_blocks: int = 2,
+        downscale_factor: int = 8,
+    ):
+        super().__init__()
+
+        in_channels = in_channels * downscale_factor**2
+        self.channels = channels
+        self.num_res_blocks = num_res_blocks
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+        self.body = []
+        for i in range(len(channels)):
+            for j in range(num_res_blocks):
+                if (i == 2) and (j == 0):
+                    self.body.append(
+                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
+                elif (i == 1) and (j == 0):
+                    self.body.append(
+                        AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
+                else:
+                    self.body.append(
+                        AdapterResnetBlock_XL(channels[i], channels[i], down=False))
+        self.body = nn.ModuleList(self.body)
+
+        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+
+        features = []
+
+        for i in range(len(self.channels)):
+            for j in range(self.num_res_blocks):
+                idx = i * self.num_res_blocks + j
+                x = self.body[idx](x)
+            features.append(x)
+
+        return features
+
+
+class AdapterResnetBlock_XL(nn.Module):
+    def __init__(self, channels_in, channels_out, down=False):
+        super().__init__()
+        if channels_in != channels_out:
+            self.in_conv = nn.Conv2d(channels_in, channels_out, 1)
+        else:
+            self.in_conv = None
+        self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)
+        self.downsample = None
+        if down:
+            self.downsample = Downsample2D(channels_in)
+
+    def forward(self, x):
+        if self.downsample is not None:
+            x = self.downsample(x)
+        if self.in_conv is not None:
+            x = self.in_conv(x)
+        h = x
+        h = self.block1(h)
+        h = self.act(h)
+        h = self.block2(h)
+
+        return h + x
+
+
 class AdapterBlock(nn.Module):
     def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
         super().__init__()
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 3203537110..b43c40a515 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -965,6 +965,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 cross_attention_kwargs=cross_attention_kwargs,
                 encoder_attention_mask=encoder_attention_mask,
             )
+            # To support T2I-Adapter-XL
+            if is_adapter and len(down_block_additional_residuals) > 0:
+                sample += down_block_additional_residuals.pop(0)
 
         if is_controlnet:
             sample = sample + mid_block_additional_residual
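Not part of the patch: below is a minimal usage sketch, assuming the changes above are applied to a local diffusers checkout. The standalone T2IAdapter call and the 1024x1024 dummy condition image are illustrative assumptions, not taken from the patch itself.

# Minimal sketch (assumes the patch above is applied to a local diffusers checkout).
import torch

from diffusers.models.adapter import T2IAdapter

# "full_adapter_xl" routes to the new FullAdapter_XL; channels follow the SDXL UNet layout.
adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type="full_adapter_xl",
)

# Dummy condition image (e.g. a sketch or depth map); the 1024x1024 size is an assumption.
cond = torch.randn(1, 3, 1024, 1024)

with torch.no_grad():
    features = adapter(cond)

# One feature map per entry in `channels`; these are the feature residuals that get passed
# to UNet2DConditionModel.forward as `down_block_additional_residuals`.
for f in features:
    print(tuple(f.shape))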