diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3a281c1594..729bce548e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -668,7 +668,7 @@ class AdaLayerNorm(nn.Module): return x -class DualTransformer2DModel(nn.Module, ConfigMixin): +class DualTransformer2DModel(nn.Module): """ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. @@ -695,9 +695,6 @@ class DualTransformer2DModel(nn.Module, ConfigMixin): Configure if the TransformerBlocks' attention should contain a bias parameter. """ - config_name = CONFIG_NAME - - @register_to_config def __init__( self, num_attention_heads: int = 16, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 106caf5c1b..93ac157b2c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -97,7 +97,21 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): image_transformer = self.image_unet.get_submodule(parent_name)[index] text_transformer = self.text_unet.get_submodule(parent_name)[index] - dual_transformer = DualTransformer2DModel(**image_transformer.config) + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) for i, type in enumerate(condition_types): if type == "image": dual_transformer.transformers[i] = image_transformer