
fix AnimateDiff creation with a unet loaded with IP Adapter (#7791)

* Fix loading from_pipe

* Fix style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Author: Fabio Rigano
Date: 2024-05-13 20:15:01 +02:00
Committed by: GitHub
Parent: fdb05f54ef
Commit: 44aa9e566d

src/diffusers/models/unets/unet_motion_model.py

@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
@@ -27,6 +28,9 @@ from ..attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    AttnProcessor2_0,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -490,6 +494,36 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         model.time_proj.load_state_dict(unet.time_proj.state_dict())
         model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
+        if any(
+            isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+            for proc in unet.attn_processors.values()
+        ):
+            attn_procs = {}
+            for name, processor in unet.attn_processors.items():
+                if name.endswith("attn1.processor"):
+                    attn_processor_class = (
+                        AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class()
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=processor.hidden_size,
+                        cross_attention_dim=processor.cross_attention_dim,
+                        scale=processor.scale,
+                        num_tokens=processor.num_tokens,
+                    )
+            for name, processor in model.attn_processors.items():
+                if name not in attn_procs:
+                    attn_procs[name] = processor.__class__()
+            model.set_attn_processor(attn_procs)
+            model.config.encoder_hid_dim_type = "ip_image_proj"
+            model.encoder_hid_proj = unet.encoder_hid_proj
+
         for i, down_block in enumerate(unet.down_blocks):
             model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
             if hasattr(model.down_blocks[i], "attentions"):
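Summing up the new branch in from_unet2d: when the source 2D UNet carries IP-Adapter attention processors, the motion model's processor dict is rebuilt rather than copied verbatim. Self-attention layers (names ending in attn1.processor) get plain AttnProcessor/AttnProcessor2_0, cross-attention layers get fresh IPAdapterAttnProcessor/IPAdapterAttnProcessor2_0 instances carrying over hidden_size, cross_attention_dim, scale, and num_tokens, motion-module processors keep their own class, and the image-projection layers plus the ip_image_proj config flag are taken from the 2D UNet. A quick sanity check, continuing the sketch above (variable names are mine, not from the commit):

from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)

motion_unet = animate_pipe.unet  # the converted UNetMotionModel

# Cross-attention layers of the motion UNet should now use IP-Adapter processors.
ip_layers = [
    name
    for name, proc in motion_unet.attn_processors.items()
    if isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
]
assert ip_layers, "expected IP-Adapter processors on cross-attention layers"
assert motion_unet.config.encoder_hid_dim_type == "ip_image_proj"
assert motion_unet.encoder_hid_proj is pipe.unet.encoder_hid_proj  # reused, not recreated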