mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
T2I-Adapter-XL
This commit is contained in:
@@ -109,6 +109,8 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
|
||||
if adapter_type == "full_adapter":
|
||||
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
elif adapter_type == "full_adapter_xl":
|
||||
self.adapter = FullAdapter_XL(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
elif adapter_type == "light_adapter":
|
||||
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
else:
|
||||
@@ -165,6 +167,80 @@ class FullAdapter(nn.Module):
|
||||
return features
|
||||
|
||||
|
||||
class FullAdapter_XL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
in_channels = in_channels * downscale_factor**2
|
||||
self.channels = channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
||||
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
|
||||
self.body = []
|
||||
for i in range(len(channels)):
|
||||
for j in range(num_res_blocks):
|
||||
if (i == 2) and (j == 0):
|
||||
self.body.append(
|
||||
AdapterResnetBlock_XL(channels[i - 1], channels[i], down=True))
|
||||
elif (i == 1) and (j == 0):
|
||||
self.body.append(
|
||||
AdapterResnetBlock_XL(channels[i - 1], channels[i], down=False))
|
||||
else:
|
||||
self.body.append(
|
||||
AdapterResnetBlock_XL(channels[i], channels[i], down=False))
|
||||
self.body = nn.ModuleList(self.body)
|
||||
|
||||
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
|
||||
features = []
|
||||
|
||||
for i in range(len(self.channels)):
|
||||
for j in range(self.num_res_blocks):
|
||||
idx = i * self.num_res_blocks + j
|
||||
x = self.body[idx](x)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class AdapterResnetBlock_XL(nn.Module):
|
||||
def __init__(self, channels_in, channels_out, down=False):
|
||||
super().__init__()
|
||||
if channels_in != channels_out:
|
||||
self.in_conv = nn.Conv2d(channels_in, channels_out, 1)
|
||||
else:
|
||||
self.in_conv = None
|
||||
self.block1 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
|
||||
self.act = nn.ReLU()
|
||||
self.block2 = nn.Conv2d(channels_out, channels_out, kernel_size=1)
|
||||
self.downsample = None
|
||||
if down:
|
||||
self.downsample = Downsample2D(channels_in)
|
||||
|
||||
def forward(self, x):
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
if self.in_conv is not None:
|
||||
x = self.in_conv(x)
|
||||
h = x
|
||||
h = self.block1(h)
|
||||
h = self.act(h)
|
||||
h = self.block2(h)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class AdapterBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
|
||||
super().__init__()
|
||||
|
||||
@@ -965,6 +965,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
# To support T2I-Adapter-XL
|
||||
if is_adapter and len(down_block_additional_residuals) > 0:
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
Reference in New Issue
Block a user