1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

DualTransformer(nn.Module)

This commit is contained in:
anton-l
2022-11-23 12:16:41 +01:00
parent 8f5f372573
commit f5e8ec6179
2 changed files with 16 additions and 5 deletions

View File

@@ -668,7 +668,7 @@ class AdaLayerNorm(nn.Module):
return x
class DualTransformer2DModel(nn.Module, ConfigMixin):
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
@@ -695,9 +695,6 @@ class DualTransformer2DModel(nn.Module, ConfigMixin):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
config_name = CONFIG_NAME
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,

View File

@@ -97,7 +97,21 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
image_transformer = self.image_unet.get_submodule(parent_name)[index]
text_transformer = self.text_unet.get_submodule(parent_name)[index]
dual_transformer = DualTransformer2DModel(**image_transformer.config)
config = image_transformer.config
dual_transformer = DualTransformer2DModel(
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
in_channels=config.in_channels,
num_layers=config.num_layers,
dropout=config.dropout,
norm_num_groups=config.norm_num_groups,
cross_attention_dim=config.cross_attention_dim,
attention_bias=config.attention_bias,
sample_size=config.sample_size,
num_vector_embeds=config.num_vector_embeds,
activation_fn=config.activation_fn,
num_embeds_ada_norm=config.num_embeds_ada_norm,
)
for i, type in enumerate(condition_types):
if type == "image":
dual_transformer.transformers[i] = image_transformer