mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix AnimateDiff creation with a unet loaded with IP Adapter (#7791)
* Fix loading from_pipe * Fix style --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
@@ -27,6 +28,9 @@ from ..attention_processor import (
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -490,6 +494,36 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
model.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if any(
|
||||
isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
|
||||
for proc in unet.attn_processors.values()
|
||||
):
|
||||
attn_procs = {}
|
||||
for name, processor in unet.attn_processors.items():
|
||||
if name.endswith("attn1.processor"):
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else IPAdapterAttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=processor.hidden_size,
|
||||
cross_attention_dim=processor.cross_attention_dim,
|
||||
scale=processor.scale,
|
||||
num_tokens=processor.num_tokens,
|
||||
)
|
||||
for name, processor in model.attn_processors.items():
|
||||
if name not in attn_procs:
|
||||
attn_procs[name] = processor.__class__()
|
||||
model.set_attn_processor(attn_procs)
|
||||
model.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
model.encoder_hid_proj = unet.encoder_hid_proj
|
||||
|
||||
for i, down_block in enumerate(unet.down_blocks):
|
||||
model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
|
||||
if hasattr(model.down_blocks[i], "attentions"):
|
||||
|
||||
Reference in New Issue
Block a user