DualTransformer(nn.Module)
@@ -668,7 +668,7 @@ class AdaLayerNorm(nn.Module):
         return x


-class DualTransformer2DModel(nn.Module, ConfigMixin):
+class DualTransformer2DModel(nn.Module):
     """
     Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

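Note on the hunk above: dropping ConfigMixin leaves the wrapper a plain nn.Module with no registered `.config` of its own. For context, the "mixed inference" named in the docstring blends the outputs of two parallel Transformer2DModels by a stored mix ratio. The sketch below is illustrative only; `MixSketch`, its constructor arguments, and the list-of-conditions forward contract are assumptions, not the exact diffusers implementation (which slices a single condition tensor across the two branches).

    # Illustrative sketch only: assumes each branch returns a
    # Transformer2DModel-style output with a `.sample` field.
    import torch.nn as nn

    class MixSketch(nn.Module):
        def __init__(self, image_t: nn.Module, text_t: nn.Module, mix_ratio: float = 0.5):
            super().__init__()
            # Mirrors the real class's `transformers` ModuleList attribute.
            self.transformers = nn.ModuleList([image_t, text_t])
            self.mix_ratio = mix_ratio

        def forward(self, hidden_states, encoder_hidden_states):
            # Same latents for both branches, per-branch conditioning,
            # then a convex blend of the two predictions.
            out_image = self.transformers[0](hidden_states, encoder_hidden_states[0]).sample
            out_text = self.transformers[1](hidden_states, encoder_hidden_states[1]).sample
            return self.mix_ratio * out_image + (1 - self.mix_ratio) * out_text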
@@ -695,9 +695,6 @@ class DualTransformer2DModel(nn.Module, ConfigMixin):
             Configure if the TransformerBlocks' attention should contain a bias parameter.
     """

-    config_name = CONFIG_NAME
-
-    @register_to_config
     def __init__(
         self,
         num_attention_heads: int = 16,
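Note on the hunk above: the two removed lines are what tie a class into diffusers' config machinery. `config_name` names the JSON file that `save_config` writes (`CONFIG_NAME` resolves to "config.json"), and `@register_to_config` snapshots the `__init__` keyword arguments into a FrozenDict exposed as `.config`. A minimal, self-contained illustration using the public diffusers config API; `Toy` is a made-up class:

    from diffusers.configuration_utils import ConfigMixin, register_to_config

    class Toy(ConfigMixin):
        config_name = "config.json"  # what CONFIG_NAME resolves to

        @register_to_config
        def __init__(self, width: int = 8):
            self.width = width

    toy = Toy(width=4)
    print(toy.config.width)  # 4 -- recorded because the kwargs were registered

After this commit, DualTransformer2DModel has neither hook, so it exposes no `.config` of its own; callers must read configuration from the wrapped Transformer2DModels, which is exactly what the pipeline change below does.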
@@ -97,7 +97,21 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
                 image_transformer = self.image_unet.get_submodule(parent_name)[index]
                 text_transformer = self.text_unet.get_submodule(parent_name)[index]

-                dual_transformer = DualTransformer2DModel(**image_transformer.config)
+                config = image_transformer.config
+                dual_transformer = DualTransformer2DModel(
+                    num_attention_heads=config.num_attention_heads,
+                    attention_head_dim=config.attention_head_dim,
+                    in_channels=config.in_channels,
+                    num_layers=config.num_layers,
+                    dropout=config.dropout,
+                    norm_num_groups=config.norm_num_groups,
+                    cross_attention_dim=config.cross_attention_dim,
+                    attention_bias=config.attention_bias,
+                    sample_size=config.sample_size,
+                    num_vector_embeds=config.num_vector_embeds,
+                    activation_fn=config.activation_fn,
+                    num_embeds_ada_norm=config.num_embeds_ada_norm,
+                )
                 for i, type in enumerate(condition_types):
                     if type == "image":
                         dual_transformer.transformers[i] = image_transformer
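Note on the hunk above: with `@register_to_config` gone, `DualTransformer2DModel.__init__` is an ordinary initializer, and unpacking `**image_transformer.config` would forward every key the source config carries; any key the wrapper does not declare would raise a TypeError. The explicit keyword list therefore decouples the wrapper from whatever extra keys Transformer2DModel's config may gain. The same effect can be expressed generically; `filtered_kwargs` below is a hypothetical helper, not part of diffusers:

    import inspect

    def filtered_kwargs(config, cls):
        # Keep only the keys that cls.__init__ explicitly declares.
        accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
        return {k: v for k, v in dict(config).items() if k in accepted}

    # dual_transformer = DualTransformer2DModel(**filtered_kwargs(config, DualTransformer2DModel))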