1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add sd1.5 compatibility to controlnet-xs and fix unused_parameters error during training (#8606)

* add sd1.5 compatibility to controlnet-xs

* set use_linear_projection by base_block

* refine code style
This commit is contained in:
Yongsen Mao
2024-06-19 05:35:34 +08:00
committed by GitHub
parent 4408047ac5
commit 16170c69ae

View File

@@ -114,6 +114,7 @@ def get_down_block_adapter(
cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True,
upcast_attention: Optional[bool] = False,
use_linear_projection: Optional[bool] = True,
):
num_layers = 2 # only support sd + sdxl
@@ -152,7 +153,7 @@ def get_down_block_adapter(
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
)
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False,
use_linear_projection: bool = True,
):
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
upcast_attention: bool = True,
max_norm_num_groups: int = 32,
use_linear_projection: bool = True,
):
super().__init__()
@@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
)
@@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
num_attention_heads=num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# up
@@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
transformer_layers_per_block=unet.config.transformer_layers_per_block,
upcast_attention=unet.config.upcast_attention,
max_norm_num_groups=unet.config.norm_num_groups,
use_linear_projection=unet.config.use_linear_projection,
)
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
upcast_attention: bool = True,
use_linear_projection: bool = True,
time_cond_proj_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
# additional controlnet configs
@@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
time_embed_dim,
cond_proj_dim=time_cond_proj_dim,
)
self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
if ctrl_learn_time_embedding:
self.ctrl_time_embedding = TimestepEmbedding(
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
)
else:
self.ctrl_time_embedding = None
if addition_embed_type is None:
self.base_add_time_proj = None
@@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
)
@@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# # Create up blocks
@@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
)
)
@@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
"addition_embed_type",
"addition_time_embed_dim",
"upcast_attention",
"use_linear_projection",
"time_cond_proj_dim",
"projection_class_embeddings_input_dim",
]
@@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True,
upcast_attention: Optional[bool] = False,
use_linear_projection: Optional[bool] = True,
):
super().__init__()
base_resnets = []
@@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
in_channels=base_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
@@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
)
@@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
use_linear_projection = base_downblock.attentions[0].use_linear_projection
else:
has_crossattn = False
transformer_layers_per_block = None
@@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
ctrl_num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
use_linear_projection = None
add_downsample = base_downblock.downsamplers is not None
# create model
@@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim,
add_downsample=add_downsample,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# # load weights
@@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False,
use_linear_projection: Optional[bool] = True,
):
super().__init__()
@@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=base_num_attention_heads,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
),
cross_attention_dim=cross_attention_dim,
num_attention_heads=ctrl_num_attention_heads,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
use_linear_projection = base_midblock.attentions[0].use_linear_projection
# create model
model = cls(
@@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# load weights
@@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
cross_attention_dim: int = 1024,
add_upsample: bool = True,
upcast_attention: bool = False,
use_linear_projection: Optional[bool] = True,
):
super().__init__()
resnets = []
@@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
@@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
num_attention_heads = get_first_cross_attention(base_upblock).heads
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
use_linear_projection = base_upblock.attentions[0].use_linear_projection
else:
has_crossattn = False
transformer_layers_per_block = None
num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
use_linear_projection = None
add_upsample = base_upblock.upsamplers is not None
# create model
@@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim,
add_upsample=add_upsample,
upcast_attention=upcast_attention,
use_linear_projection=use_linear_projection,
)
# load weights