mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] Refactor IPAdapterPlusImageProjection a bit (#7994)
* use IPAdapterPlusImageProjectionBlock in IPAdapterPlusImageProjection * reposition IPAdapterPlusImageProjection * refactor complete? * fix heads param retrieval. * update test dict creation method.
This commit is contained in:
@@ -847,7 +847,12 @@ class UNet2DConditionLoadersMixin:
|
||||
embed_dims = state_dict["proj_in.weight"].shape[1]
|
||||
output_dims = state_dict["proj_out.weight"].shape[0]
|
||||
hidden_dims = state_dict["latents"].shape[2]
|
||||
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
attn_key_present = any("attn" in k for k in state_dict)
|
||||
heads = (
|
||||
state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
if attn_key_present
|
||||
else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
)
|
||||
|
||||
with init_context():
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
@@ -860,26 +865,53 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("0.to", "2.to")
|
||||
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
|
||||
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
|
||||
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
|
||||
|
||||
if "norm1" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
|
||||
elif "norm2" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
|
||||
elif "to_kv" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
|
||||
diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
|
||||
diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
|
||||
diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
|
||||
diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
|
||||
diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
|
||||
diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
|
||||
diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
|
||||
|
||||
if "to_kv" in diffusers_name:
|
||||
parts = diffusers_name.split(".")
|
||||
parts[2] = "attn"
|
||||
diffusers_name = ".".join(parts)
|
||||
v_chunk = value.chunk(2, dim=0)
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
|
||||
elif "to_q" in diffusers_name:
|
||||
parts = diffusers_name.split(".")
|
||||
parts[2] = "attn"
|
||||
diffusers_name = ".".join(parts)
|
||||
updated_state_dict[diffusers_name] = value
|
||||
elif "to_out" in diffusers_name:
|
||||
parts = diffusers_name.split(".")
|
||||
parts[2] = "attn"
|
||||
diffusers_name = ".".join(parts)
|
||||
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
|
||||
diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
|
||||
|
||||
diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
|
||||
diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
|
||||
|
||||
diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
|
||||
diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
|
||||
|
||||
diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
|
||||
diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict)
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
@@ -806,89 +806,6 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjection(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
||||
that is the same
|
||||
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
||||
hidden_dims (int):
|
||||
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
||||
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
||||
Defaults to 16. num_queries (int):
|
||||
The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
|
||||
of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int = 768,
|
||||
output_dims: int = 1024,
|
||||
hidden_dims: int = 1280,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_queries: int = 8,
|
||||
ffn_ratio: float = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward # Lazy import to avoid circular import
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embed_dims, hidden_dims)
|
||||
|
||||
self.proj_out = nn.Linear(hidden_dims, output_dims)
|
||||
self.norm_out = nn.LayerNorm(output_dims)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
nn.LayerNorm(hidden_dims),
|
||||
nn.LayerNorm(hidden_dims),
|
||||
Attention(
|
||||
query_dim=hidden_dims,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
out_bias=False,
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(hidden_dims),
|
||||
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input Tensor.
|
||||
Returns:
|
||||
torch.Tensor: Output Tensor.
|
||||
"""
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for ln0, ln1, attn, ff in self.layers:
|
||||
residual = latents
|
||||
|
||||
encoder_hidden_states = ln0(x)
|
||||
latents = ln1(latents)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
|
||||
latents = attn(latents, encoder_hidden_states) + residual
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjectionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -922,6 +839,65 @@ class IPAdapterPlusImageProjectionBlock(nn.Module):
|
||||
return latents
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjection(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
||||
that is the same
|
||||
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
||||
hidden_dims (int):
|
||||
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
||||
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
||||
Defaults to 16. num_queries (int):
|
||||
The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
|
||||
of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int = 768,
|
||||
output_dims: int = 1024,
|
||||
hidden_dims: int = 1280,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_queries: int = 8,
|
||||
ffn_ratio: float = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embed_dims, hidden_dims)
|
||||
|
||||
self.proj_out = nn.Linear(hidden_dims, output_dims)
|
||||
self.norm_out = nn.LayerNorm(output_dims)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input Tensor.
|
||||
Returns:
|
||||
torch.Tensor: Output Tensor.
|
||||
"""
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for block in self.layers:
|
||||
residual = latents
|
||||
latents = block(x, latents, residual)
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
||||
"""FacePerceiverResampler of IP-Adapter Plus.
|
||||
|
||||
|
||||
@@ -146,42 +146,64 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = OrderedDict()
|
||||
keys = [k for k in image_projection.state_dict() if "layers." in k]
|
||||
print(keys)
|
||||
for k, v in image_projection.state_dict().items():
|
||||
if "2.to" in k:
|
||||
k = k.replace("2.to", "0.to")
|
||||
elif "3.0.weight" in k:
|
||||
k = k.replace("3.0.weight", "1.0.weight")
|
||||
elif "3.0.bias" in k:
|
||||
k = k.replace("3.0.bias", "1.0.bias")
|
||||
elif "3.0.weight" in k:
|
||||
k = k.replace("3.0.weight", "1.0.weight")
|
||||
elif "3.1.net.0.proj.weight" in k:
|
||||
k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
|
||||
elif "3.net.2.weight" in k:
|
||||
k = k.replace("3.net.2.weight", "1.3.weight")
|
||||
elif "layers.0.0" in k:
|
||||
k = k.replace("layers.0.0", "layers.0.0.norm1")
|
||||
elif "layers.0.1" in k:
|
||||
k = k.replace("layers.0.1", "layers.0.0.norm2")
|
||||
elif "layers.1.0" in k:
|
||||
k = k.replace("layers.1.0", "layers.1.0.norm1")
|
||||
elif "layers.1.1" in k:
|
||||
k = k.replace("layers.1.1", "layers.1.0.norm2")
|
||||
elif "layers.2.0" in k:
|
||||
k = k.replace("layers.2.0", "layers.2.0.norm1")
|
||||
elif "layers.2.1" in k:
|
||||
k = k.replace("layers.2.1", "layers.2.0.norm2")
|
||||
elif "layers.0.ln0" in k:
|
||||
k = k.replace("layers.0.ln0", "layers.0.0.norm1")
|
||||
elif "layers.0.ln1" in k:
|
||||
k = k.replace("layers.0.ln1", "layers.0.0.norm2")
|
||||
elif "layers.1.ln0" in k:
|
||||
k = k.replace("layers.1.ln0", "layers.1.0.norm1")
|
||||
elif "layers.1.ln1" in k:
|
||||
k = k.replace("layers.1.ln1", "layers.1.0.norm2")
|
||||
elif "layers.2.ln0" in k:
|
||||
k = k.replace("layers.2.ln0", "layers.2.0.norm1")
|
||||
elif "layers.2.ln1" in k:
|
||||
k = k.replace("layers.2.ln1", "layers.2.0.norm2")
|
||||
elif "layers.3.ln0" in k:
|
||||
k = k.replace("layers.3.ln0", "layers.3.0.norm1")
|
||||
elif "layers.3.ln1" in k:
|
||||
k = k.replace("layers.3.ln1", "layers.3.0.norm2")
|
||||
elif "to_q" in k:
|
||||
parts = k.split(".")
|
||||
parts[2] = "attn"
|
||||
k = ".".join(parts)
|
||||
elif "to_out.0" in k:
|
||||
parts = k.split(".")
|
||||
parts[2] = "attn"
|
||||
k = ".".join(parts)
|
||||
k = k.replace("to_out.0", "to_out")
|
||||
else:
|
||||
k = k.replace("0.ff.0", "0.1.0")
|
||||
k = k.replace("0.ff.1.net.0.proj", "0.1.1")
|
||||
k = k.replace("0.ff.1.net.2", "0.1.3")
|
||||
|
||||
if "norm_cross" in k:
|
||||
ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
|
||||
elif "layer_norm" in k:
|
||||
ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
|
||||
elif "to_k" in k:
|
||||
k = k.replace("1.ff.0", "1.1.0")
|
||||
k = k.replace("1.ff.1.net.0.proj", "1.1.1")
|
||||
k = k.replace("1.ff.1.net.2", "1.1.3")
|
||||
|
||||
k = k.replace("2.ff.0", "2.1.0")
|
||||
k = k.replace("2.ff.1.net.0.proj", "2.1.1")
|
||||
k = k.replace("2.ff.1.net.2", "2.1.3")
|
||||
|
||||
k = k.replace("3.ff.0", "3.1.0")
|
||||
k = k.replace("3.ff.1.net.0.proj", "3.1.1")
|
||||
k = k.replace("3.ff.1.net.2", "3.1.3")
|
||||
|
||||
# if "norm_cross" in k:
|
||||
# ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
|
||||
# elif "layer_norm" in k:
|
||||
# ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
|
||||
if "to_k" in k:
|
||||
parts = k.split(".")
|
||||
parts[2] = "attn"
|
||||
k = ".".join(parts)
|
||||
ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
|
||||
elif "to_v" in k:
|
||||
continue
|
||||
elif "to_out.0" in k:
|
||||
ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
|
||||
else:
|
||||
ip_image_projection_state_dict[k] = v
|
||||
|
||||
|
||||
Reference in New Issue
Block a user