mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add sd1.5 compatibility to controlnet-xs and fix unused_parameters error during training (#8606)
* add sd1.5 compatibility to controlnet-xs * set use_linear_projection by base_block * refine code style
This commit is contained in:
@@ -114,6 +114,7 @@ def get_down_block_adapter(
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
add_downsample: bool = True,
|
||||
upcast_attention: Optional[bool] = False,
|
||||
use_linear_projection: Optional[bool] = True,
|
||||
):
|
||||
num_layers = 2 # only support sd + sdxl
|
||||
|
||||
@@ -152,7 +153,7 @@ def get_down_block_adapter(
|
||||
in_channels=ctrl_out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
|
||||
)
|
||||
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
upcast_attention: bool = False,
|
||||
use_linear_projection: bool = True,
|
||||
):
|
||||
# Before the midblock application, information is concatted from base to control.
|
||||
# Concat doesn't require change in number of channels
|
||||
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
|
||||
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
upcast_attention: bool = True,
|
||||
max_norm_num_groups: int = 32,
|
||||
use_linear_projection: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# up
|
||||
@@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
transformer_layers_per_block=unet.config.transformer_layers_per_block,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
max_norm_num_groups=unet.config.norm_num_groups,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
)
|
||||
|
||||
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
|
||||
@@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
upcast_attention: bool = True,
|
||||
use_linear_projection: bool = True,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
# additional controlnet configs
|
||||
@@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
time_embed_dim,
|
||||
cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
|
||||
if ctrl_learn_time_embedding:
|
||||
self.ctrl_time_embedding = TimestepEmbedding(
|
||||
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
else:
|
||||
self.ctrl_time_embedding = None
|
||||
|
||||
if addition_embed_type is None:
|
||||
self.base_add_time_proj = None
|
||||
@@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# # Create up blocks
|
||||
@@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
add_upsample=not is_final_block,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_num_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
"addition_embed_type",
|
||||
"addition_time_embed_dim",
|
||||
"upcast_attention",
|
||||
"use_linear_projection",
|
||||
"time_cond_proj_dim",
|
||||
"projection_class_embeddings_input_dim",
|
||||
]
|
||||
@@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
add_downsample: bool = True,
|
||||
upcast_attention: Optional[bool] = False,
|
||||
use_linear_projection: Optional[bool] = True,
|
||||
):
|
||||
super().__init__()
|
||||
base_resnets = []
|
||||
@@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
in_channels=base_out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_num_groups=norm_num_groups,
|
||||
)
|
||||
@@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
in_channels=ctrl_out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
|
||||
)
|
||||
@@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
|
||||
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
|
||||
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
|
||||
use_linear_projection = base_downblock.attentions[0].use_linear_projection
|
||||
else:
|
||||
has_crossattn = False
|
||||
transformer_layers_per_block = None
|
||||
@@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
ctrl_num_attention_heads = None
|
||||
cross_attention_dim = None
|
||||
upcast_attention = None
|
||||
use_linear_projection = None
|
||||
add_downsample = base_downblock.downsamplers is not None
|
||||
|
||||
# create model
|
||||
@@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
add_downsample=add_downsample,
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# # load weights
|
||||
@@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
ctrl_num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
upcast_attention: bool = False,
|
||||
use_linear_projection: Optional[bool] = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=base_num_attention_heads,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=ctrl_num_attention_heads,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
|
||||
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
|
||||
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
|
||||
use_linear_projection = base_midblock.attentions[0].use_linear_projection
|
||||
|
||||
# create model
|
||||
model = cls(
|
||||
@@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
ctrl_num_attention_heads=ctrl_num_attention_heads,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# load weights
|
||||
@@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
cross_attention_dim: int = 1024,
|
||||
add_upsample: bool = True,
|
||||
upcast_attention: bool = False,
|
||||
use_linear_projection: Optional[bool] = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=True,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_num_groups=norm_num_groups,
|
||||
)
|
||||
@@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
num_attention_heads = get_first_cross_attention(base_upblock).heads
|
||||
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
|
||||
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
|
||||
use_linear_projection = base_upblock.attentions[0].use_linear_projection
|
||||
else:
|
||||
has_crossattn = False
|
||||
transformer_layers_per_block = None
|
||||
num_attention_heads = None
|
||||
cross_attention_dim = None
|
||||
upcast_attention = None
|
||||
use_linear_projection = None
|
||||
add_upsample = base_upblock.upsamplers is not None
|
||||
|
||||
# create model
|
||||
@@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
add_upsample=add_upsample,
|
||||
upcast_attention=upcast_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# load weights
|
||||
|
||||
Reference in New Issue
Block a user