diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index a092daa662..1b62d16d5d 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
@@ -27,6 +28,9 @@ from ..attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    AttnProcessor2_0,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -490,6 +494,36 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         model.time_proj.load_state_dict(unet.time_proj.state_dict())
         model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
+        if any(
+            isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+            for proc in unet.attn_processors.values()
+        ):
+            attn_procs = {}
+            for name, processor in unet.attn_processors.items():
+                if name.endswith("attn1.processor"):
+                    attn_processor_class = (
+                        AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class()
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=processor.hidden_size,
+                        cross_attention_dim=processor.cross_attention_dim,
+                        scale=processor.scale,
+                        num_tokens=processor.num_tokens,
+                    )
+            for name, processor in model.attn_processors.items():
+                if name not in attn_procs:
+                    attn_procs[name] = processor.__class__()
+            model.set_attn_processor(attn_procs)
+            model.config.encoder_hid_dim_type = "ip_image_proj"
+            model.encoder_hid_proj = unet.encoder_hid_proj
+
         for i, down_block in enumerate(unet.down_blocks):
             model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
             if hasattr(model.down_blocks[i], "attentions"):
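
A minimal usage sketch of the path this change covers: converting a 2D UNet that already carries IP-Adapter attention processors into a UNetMotionModel. The checkpoint IDs (runwayml/stable-diffusion-v1-5, h94/IP-Adapter, guoyww/animatediff-motion-adapter-v1-5-2) are illustrative and not part of the diff:

# Hedged sketch, not part of the patch: checkpoint names are examples only.
from diffusers import MotionAdapter, StableDiffusionPipeline, UNetMotionModel

# Loading IP-Adapter weights replaces the 2D UNet's cross-attention processors
# with IPAdapterAttnProcessor / IPAdapterAttnProcessor2_0 instances.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# With this patch, from_unet2d rebuilds matching processors on the motion UNet
# instead of resetting everything to plain attention, and carries over
# encoder_hid_proj so image-prompt conditioning survives the conversion.
motion_unet = UNetMotionModel.from_unet2d(pipe.unet, motion_adapter=adapter)
assert motion_unet.config.encoder_hid_dim_type == "ip_image_proj"

Design note: self-attention (attn1) processors are reset to plain AttnProcessor / AttnProcessor2_0 because IP-Adapter only conditions cross-attention, and the motion-module processors that exist only on the UNetMotionModel fall through to their default classes via the `name not in attn_procs` branch.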