Move IP Adapter Face ID to core (#7186)
* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -21,6 +21,7 @@ from safetensors import safe_open
 
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
+    USE_PEFT_BACKEND,
     _get_model_file,
     is_accelerate_available,
     is_torch_version,
@@ -228,6 +229,18 @@ class IPAdapterMixin:
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
+        extra_loras = unet._load_ip_adapter_loras(state_dicts)
+        if extra_loras != {}:
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
     def set_ip_adapter_scale(self, scale):
         """
         Sets the conditioning scale between text and image.
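The new loading path in context, as a minimal usage sketch. Model and weight names follow the FaceID examples in the diffusers docs; treat them as illustrative. FaceID checkpoints ship no CLIP image encoder, so the ID embedding is computed externally (e.g. with insightface) and passed to the pipeline as `ip_adapter_image_embeds` at inference time.

```python
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# FaceID checkpoints bundle no image encoder, hence image_encoder_folder=None.
# Loading runs the code above: projection weights go through
# _load_ip_adapter_weights, and the bundled LoRA weights go through
# _load_ip_adapter_loras, registering a "faceid_0" adapter via the PEFT backend.
pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(0.6)
```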
@@ -27,6 +27,8 @@ from torch import nn
 
 from ..models.embeddings import (
     ImageProjection,
+    IPAdapterFaceIDImageProjection,
+    IPAdapterFaceIDPlusImageProjection,
     IPAdapterFullImageProjection,
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
@@ -756,6 +758,90 @@ class UNet2DConditionLoadersMixin:
                 diffusers_name = diffusers_name.replace("proj.3", "norm")
                 updated_state_dict[diffusers_name] = value
 
+        elif "perceiver_resampler.proj_in.weight" in state_dict:
+            # IP-Adapter Face ID Plus
+            id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
+            embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
+            hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
+            output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
+            heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
+
+            with init_context():
+                image_projection = IPAdapterFaceIDPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    id_embeddings_dim=id_embeddings_dim,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("perceiver_resampler.", "")
+                diffusers_name = diffusers_name.replace("0.to", "attn.to")
+                diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
+                diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
+                diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
+                diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
+                diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
+                diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
+                diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
+                diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
+                diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
+                diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
+                diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
+                diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                elif "proj.0.weight" == diffusers_name:
+                    updated_state_dict["proj.net.0.proj.weight"] = value
+                elif "proj.0.bias" == diffusers_name:
+                    updated_state_dict["proj.net.0.proj.bias"] = value
+                elif "proj.2.weight" == diffusers_name:
+                    updated_state_dict["proj.net.2.weight"] = value
+                elif "proj.2.bias" == diffusers_name:
+                    updated_state_dict["proj.net.2.bias"] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        elif "norm.weight" in state_dict:
+            # IP-Adapter Face ID
+            id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+            id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+            multiplier = id_embeddings_dim_out // id_embeddings_dim_in
+            norm_layer = "norm.weight"
+            cross_attention_dim = state_dict[norm_layer].shape[0]
+            num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+            with init_context():
+                image_projection = IPAdapterFaceIDImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=id_embeddings_dim_in,
+                    mult=multiplier,
+                    num_tokens=num_tokens,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                updated_state_dict[diffusers_name] = value
+
         else:
             # IP-Adapter Plus
             num_image_text_embeds = state_dict["latents"].shape[1]
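Two quick standalone checks on the conversions above; all shapes are made up for illustration, only the arithmetic mirrors the branch logic.

```python
import torch

# 1) Face ID Plus: the checkpoint fuses the K and V projections into one
#    to_kv tensor; the converter chunks it along dim 0 into to_k / to_v.
fused_kv = torch.randn(2 * 768, 1280)  # illustrative shape
to_k, to_v = fused_kv.chunk(2, dim=0)
assert to_k.shape == to_v.shape == (768, 1280)

# 2) Face ID: constructor arguments are inferred from checkpoint shapes.
#    Fake a projection MLP of Linear(512 -> 1024) and Linear(1024 -> 4 * 768):
state_dict = {
    "proj.0.weight": torch.randn(1024, 512),
    "proj.2.weight": torch.randn(3072, 1024),
    "norm.weight": torch.randn(768),
}
id_dim_in = state_dict["proj.0.weight"].shape[1]                 # 512
multiplier = state_dict["proj.0.weight"].shape[0] // id_dim_in   # 2
cross_dim = state_dict["norm.weight"].shape[0]                   # 768
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_dim   # 4
assert (id_dim_in, multiplier, cross_dim, num_tokens) == (512, 2, 768, 4)
```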
@@ -847,6 +933,7 @@ class UNet2DConditionLoadersMixin:
                     AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                 )
                 attn_procs[name] = attn_processor_class()
+
             else:
                 attn_processor_class = (
                     IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
@@ -859,6 +946,12 @@ class UNet2DConditionLoadersMixin:
                     elif "proj.3.weight" in state_dict["image_proj"]:
                         # IP-Adapter Full Face
                         num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token
+                    elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
+                        # IP-Adapter Face ID Plus
+                        num_image_text_embeds += [4]
+                    elif "norm.weight" in state_dict["image_proj"]:
+                        # IP-Adapter Face ID
+                        num_image_text_embeds += [4]
                     else:
                         # IP-Adapter Plus
                         num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
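This branch amounts to checkpoint-type detection by key. A standalone restatement of the chain (the first branch, "proj.weight" for the plain IP-Adapter, comes from the surrounding file rather than this hunk):

```python
def classify_image_proj(image_proj: dict) -> str:
    # Mirrors the elif chain above; order matters, since later formats
    # can share some keys with earlier ones.
    if "proj.weight" in image_proj:
        return "IP-Adapter"
    if "proj.3.weight" in image_proj:
        return "IP-Adapter Full Face"
    if "perceiver_resampler.proj_in.weight" in image_proj:
        return "IP-Adapter Face ID Plus"
    if "norm.weight" in image_proj:
        return "IP-Adapter Face ID"
    return "IP-Adapter Plus"
```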
@@ -910,6 +1003,59 @@ class UNet2DConditionLoadersMixin:
 
         self.to(dtype=self.dtype, device=self.device)
 
+    def _load_ip_adapter_loras(self, state_dicts):
+        lora_dicts = {}
+        for key_id, name in enumerate(self.attn_processors.keys()):
+            for i, state_dict in enumerate(state_dicts):
+                if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
+                    if i not in lora_dicts:
+                        lora_dicts[i] = {}
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_k_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_q_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_v_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_out_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_out_lora.up.weight"
+                            ]
+                        }
+                    )
+        return lora_dicts
+
 
 class FromOriginalUNetMixin:
     """
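A sketch of the mapping `_load_ip_adapter_loras` builds; module path and tensor shapes are hypothetical. Checkpoint-side keys such as "1.to_k_lora.down.weight" are re-prefixed with the attention processor's module path so that `load_lora_weights` can route them into the UNet:

```python
import torch

# Hypothetical attention-processor name and rank-128 LoRA tensors.
attn_name = "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
lora_dicts = {
    0: {  # index of the IP-Adapter state dict the LoRA came from
        f"unet.{attn_name}.to_k_lora.down.weight": torch.randn(128, 320),
        f"unet.{attn_name}.to_k_lora.up.weight": torch.randn(320, 128),
        # ...same pattern for to_q, to_v, and to_out
    }
}
print(sorted(lora_dicts[0]))
```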
@@ -472,6 +472,22 @@ class IPAdapterFullImageProjection(nn.Module):
         return self.norm(self.ff(image_embeds))
 
 
+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.FloatTensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
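A shape check for the new module (assuming it is importable from `diffusers.models.embeddings` once this commit lands): each 512-dim ID embedding becomes `num_tokens` context tokens of width `cross_attention_dim`.

```python
import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

proj = IPAdapterFaceIDImageProjection(
    image_embed_dim=512, cross_attention_dim=768, mult=2, num_tokens=4
)
tokens = proj(torch.randn(2, 512))  # batch of two ID embeddings
print(tokens.shape)                 # torch.Size([2, 4, 768])
```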
@@ -794,13 +810,14 @@ class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
-    ----
         embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
         that is the same
            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
-        hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
         to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
-        Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
         of feedforward network hidden
         layer channels. Defaults to 4.
     """
@@ -851,11 +868,8 @@ class IPAdapterPlusImageProjection(nn.Module):
         """Forward pass.
 
         Args:
-        ----
             x (torch.Tensor): Input Tensor.
-
         Returns:
-        -------
             torch.Tensor: Output Tensor.
         """
         latents = self.latents.repeat(x.size(0), 1, 1)
@@ -875,6 +889,119 @@ class IPAdapterPlusImageProjection(nn.Module):
         return self.norm_out(latents)
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
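And the matching shape check for the FaceID Plus resampler (import path assumed as above). Note that the CLIP features enter through the `clip_embeds` attribute rather than a forward argument, and the forward dereferences it, so it must be set before calling the module; the 4-D shape below matches the reshape over dims 2 and 3.

```python
import torch
from diffusers.models.embeddings import IPAdapterFaceIDPlusImageProjection

proj = IPAdapterFaceIDPlusImageProjection(
    embed_dims=768, output_dims=768, hidden_dims=1280, id_embeddings_dim=512
)
proj.clip_embeds = torch.randn(1, 1, 257, 1280)  # (batch, num_images, seq_len, hidden), illustrative
out = proj(torch.randn(1, 512))                  # one InsightFace-style ID embedding
print(out.shape)                                 # torch.Size([1, 4, 768])
```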