
Move IP Adapter Face ID to core (#7186)

* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Fabio Rigano
Date: 2024-04-19 02:13:27 +02:00
Committed by: GitHub
Parent: e23c27e905
Commit: b5c8b555d7
10 changed files with 595 additions and 378 deletions


@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable the PEFT backend in order to load the weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
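As a quick orientation before the documented example below, here is a minimal sketch of how such a face embedding is typically produced with `insightface` and handed to the pipeline as `image_embeds`; the model name, image path, and the commented pipeline call are illustrative assumptions, not part of this commit.
```py
# Minimal sketch: assumed insightface API and placeholder paths; pipeline setup omitted.
import cv2
import torch
from insightface.app import FaceAnalysis

app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

image = cv2.imread("face.jpg")  # reference face, loaded as a BGR array
faces = app.get(image)  # detect faces and compute face recognition embeddings
image_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)  # shape (1, 512)

# The embedding is passed as `image_embeds` instead of `ip_adapter_image`, e.g.:
# images = pipeline(prompt="a photo of a person", image_embeds=image_embeds).images
```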
```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2


@@ -26,7 +26,14 @@ from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import MultiIPAdapterImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -45,300 +52,6 @@ from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LoRAIPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapter.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
lora_scale (`float`, defaults to 1.0):
The weight scale of LoRA.
scale (`float`, defaults to 1.0):
The weight scale of the image prompt.
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(
self,
hidden_size,
cross_attention_dim=None,
rank=4,
network_alpha=None,
lora_scale=1.0,
scale=1.0,
num_tokens=4,
):
super().__init__()
self.rank = rank
self.lora_scale = lora_scale
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class LoRAIPAdapterAttnProcessor2_0(nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
lora_scale (`float`, defaults to 1.0):
The weight scale of LoRA.
scale (`float`, defaults to 1.0):
The weight scale of the image prompt.
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(
self,
hidden_size,
cross_attention_dim=None,
rank=4,
network_alpha=None,
lora_scale=1.0,
scale=1.0,
num_tokens=4,
):
super().__init__()
self.rank = rank
self.lora_scale = lora_scale
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
super().__init__()
@@ -615,17 +328,13 @@ class IPAdapterFaceIDStableDiffusionPipeline(
return image_projection
def _load_ip_adapter_weights(self, state_dict):
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
)
num_image_text_embeds = 4
self.unet.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
lora_dict = {}
key_id = 0
for name in self.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
@@ -642,94 +351,99 @@ class IPAdapterFaceIDStableDiffusionPipeline(
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
attn_module = self.unet
for n in name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
rank=rank,
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
rank=rank,
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
rank=rank,
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=rank,
)
)
value_dict = {}
for k, module in attn_module.named_children():
index = "."
if not hasattr(module, "set_lora_layer"):
index = ".0."
module = module[0]
lora_layer = getattr(module, "lora_layer")
for lora_name, w in lora_layer.state_dict().items():
value_dict.update(
{
f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][
f"{key_id}.{k}_lora.{lora_name}"
]
}
)
attn_module.load_state_dict(value_dict, strict=False)
attn_module.to(dtype=self.dtype, device=self.device)
lora_dict.update(
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
)
lora_dict.update(
{
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.down.weight"
]
}
)
lora_dict.update(
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
)
key_id += 1
else:
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
attn_processor_class = (
LoRAIPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAIPAdapterAttnProcessor
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
rank=rank,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
lora_dict.update(
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
)
lora_dict.update(
{
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.down.weight"
]
}
)
lora_dict.update(
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
)
value_dict = {}
value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 1
self.unet.set_attn_processor(attn_procs)
self.load_lora_weights(lora_dict, adapter_name="faceid")
self.set_adapters(["faceid"], adapter_weights=[1.0])
# convert IP-Adapter Image Projection layers to diffusers
image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.unet.config.encoder_hid_dim_type = "ip_image_proj"
def set_ip_adapter_scale(self, scale):
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = [scale]
def _encode_prompt(
self,
@@ -1298,7 +1012,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
negative_image_embeds = torch.zeros_like(image_embeds)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = [image_embeds]
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -1319,7 +1033,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {}
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None
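
With the processors above switched to the core `IPAdapterAttnProcessor` classes, the FaceID LoRA weights are registered through the PEFT backend under the adapter name `faceid`, and the IP-Adapter scale is kept as a per-adapter list. A minimal sketch of tuning both after the weights are loaded, using only the methods that appear in this diff (the weight values are arbitrary examples):
```py
# Sketch only: the adapter name "faceid" and the methods below follow the diff above;
# the weight values are arbitrary examples, not recommendations.
pipeline.set_adapters(["faceid"], adapter_weights=[0.7])  # LoRA strength via the PEFT backend
pipeline.set_ip_adapter_scale(0.6)  # stored as [0.6] on each IP-Adapter attention processor
```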