diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py index c2a52a3515..a060dc1bbe 100644 --- a/src/diffusers/models/controlnet_xs.py +++ b/src/diffusers/models/controlnet_xs.py @@ -114,6 +114,7 @@ def get_down_block_adapter( cross_attention_dim: Optional[int] = 1024, add_downsample: bool = True, upcast_attention: Optional[bool] = False, + use_linear_projection: Optional[bool] = True, ): num_layers = 2 # only support sd + sdxl @@ -152,7 +153,7 @@ def get_down_block_adapter( in_channels=ctrl_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), ) @@ -200,6 +201,7 @@ def get_mid_block_adapter( num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: bool = False, + use_linear_projection: bool = True, ): # Before the midblock application, information is concatted from base to control. # Concat doesn't require change in number of channels @@ -214,7 +216,7 @@ def get_mid_block_adapter( resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin): transformer_layers_per_block: Union[int, Tuple[int]] = 1, upcast_attention: bool = True, max_norm_num_groups: int = 32, + use_linear_projection: bool = True, ): super().__init__() @@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin): cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) ) @@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin): num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) # up @@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin): transformer_layers_per_block=unet.config.transformer_layers_per_block, upcast_attention=unet.config.upcast_attention, max_norm_num_groups=unet.config.norm_num_groups, + use_linear_projection=unet.config.use_linear_projection, ) # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel @@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, upcast_attention: bool = True, + use_linear_projection: bool = True, time_cond_proj_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, # additional controlnet configs @@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): time_embed_dim, cond_proj_dim=time_cond_proj_dim, ) - self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) + if ctrl_learn_time_embedding: + self.ctrl_time_embedding = TimestepEmbedding( + in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim + ) + else: + self.ctrl_time_embedding = None if addition_embed_type is None: self.base_add_time_proj = None @@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) ) @@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): ctrl_num_attention_heads=ctrl_num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) # # Create up blocks @@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): add_upsample=not is_final_block, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, + use_linear_projection=use_linear_projection, ) ) @@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): "addition_embed_type", "addition_time_embed_dim", "upcast_attention", + "use_linear_projection", "time_cond_proj_dim", "projection_class_embeddings_input_dim", ] @@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): cross_attention_dim: Optional[int] = 1024, add_downsample: bool = True, upcast_attention: Optional[bool] = False, + use_linear_projection: Optional[bool] = True, ): super().__init__() base_resnets = [] @@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): in_channels=base_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, ) @@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): in_channels=ctrl_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), ) @@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim upcast_attention = get_first_cross_attention(base_downblock).upcast_attention + use_linear_projection = base_downblock.attentions[0].use_linear_projection else: has_crossattn = False transformer_layers_per_block = None @@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): ctrl_num_attention_heads = None cross_attention_dim = None upcast_attention = None + use_linear_projection = None add_downsample = base_downblock.downsamplers is not None # create model @@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module): cross_attention_dim=cross_attention_dim, add_downsample=add_downsample, upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) # # load weights @@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module): ctrl_num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, upcast_attention: bool = False, + use_linear_projection: Optional[bool] = True, ): super().__init__() @@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module): resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=base_num_attention_heads, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module): ), cross_attention_dim=cross_attention_dim, num_attention_heads=ctrl_num_attention_heads, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module): ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim upcast_attention = get_first_cross_attention(base_midblock).upcast_attention + use_linear_projection = base_midblock.attentions[0].use_linear_projection # create model model = cls( @@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module): ctrl_num_attention_heads=ctrl_num_attention_heads, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) # load weights @@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): cross_attention_dim: int = 1024, add_upsample: bool = True, upcast_attention: bool = False, + use_linear_projection: Optional[bool] = True, ): super().__init__() resnets = [] @@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - use_linear_projection=True, + use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, ) @@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): num_attention_heads = get_first_cross_attention(base_upblock).heads cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim upcast_attention = get_first_cross_attention(base_upblock).upcast_attention + use_linear_projection = base_upblock.attentions[0].use_linear_projection else: has_crossattn = False transformer_layers_per_block = None num_attention_heads = None cross_attention_dim = None upcast_attention = None + use_linear_projection = None add_upsample = base_upblock.upsamplers is not None # create model @@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module): cross_attention_dim=cross_attention_dim, add_upsample=add_upsample, upcast_attention=upcast_attention, + use_linear_projection=use_linear_projection, ) # load weights