Move IP Adapter Face ID to core (#7186)
* Switch to peft and multi proj layers
* Move Face ID loading and inference to core

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma

### Face model

Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:

* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces

> [!TIP]
>
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.

Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.

For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.
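To make the scheduler recommendation concrete, here is a minimal sketch of loading one of the face checkpoints with [`DDIMScheduler`]; the prompt, adapter scale, and reference image below are illustrative placeholders rather than part of this commit.

```py
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# face checkpoints tend to behave better with DDIM or Euler schedulers
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.safetensors")
pipeline.set_ip_adapter_scale(0.5)

face_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
image = pipeline(
    prompt="a photo of a girl in a garden",
    ip_adapter_image=face_image,
    negative_prompt="lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=30,
    generator=torch.Generator(device="cpu").manual_seed(42),
).images[0]
```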
@@ -411,6 +409,56 @@ image
</div>
</div>

To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.

```py
import cv2
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
from insightface.app import FaceAnalysis

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")

# extract the face ID embedding with insightface
ref_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding)
ref_images_embeds.append(image.unsqueeze(0))
ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")

generator = torch.Generator(device="cpu").manual_seed(42)

images = pipeline(
    prompt="A photo of a girl",
    ip_adapter_image_embeds=[id_embeds],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20, num_images_per_prompt=1,
    generator=generator
).images
```
Both the IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. Prepare the face embeddings as shown previously, then extract the CLIP embeddings and pass them to the hidden image projection layers.

```py
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]

pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
```
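The snippet above assumes `ip_adapter_images` and `num_images` are already defined from the earlier face-embedding step; a minimal sketch of plausible values (purely illustrative):

```py
from diffusers.utils import load_image

# illustrative placeholders for the names used above
ip_adapter_images = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
num_images = 1
```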
### Multi IP-Adapter

More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
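As a rough sketch of what that can look like in practice (the checkpoint names and scales here are illustrative, not prescribed by this commit), `load_ip_adapter` and `set_ip_adapter_scale` accept lists so each adapter gets its own weights and strength:

```py
import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import AutoPipelineForText2Image

# the ViT-H "plus" adapters share one image encoder
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")

# one face adapter and one style adapter, weighted separately
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors", "ip-adapter-plus_sdxl_vit-h.safetensors"],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
```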
@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```

### IP-Adapter Face ID models

The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
You need to install `insightface` and all its requirements to use these models.

<Tip warning={true}>
As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
</Tip>

```py
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
```

If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as these models use both `insightface` and CLIP image embeddings to achieve better photorealism.

```py
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
```
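Continuing from the block above, inference with a FaceID Plus checkpoint needs both the `insightface` ID embedding and the CLIP embedding of the reference face; the following is a condensed sketch (the reference image URL, prompt, and step count are placeholders) of wiring the two together.

```py
import cv2
import numpy as np
import torch
from insightface.app import FaceAnalysis
from diffusers.utils import load_image

face_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")

# insightface face ID embedding, shaped (2, 1, 1, 512) for classifier-free guidance
app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
faces = app.get(cv2.cvtColor(np.asarray(face_image), cv2.COLOR_BGR2RGB))
id_embeds = torch.from_numpy(faces[0].normed_embedding).reshape(1, 1, 1, -1)
id_embeds = torch.cat([torch.zeros_like(id_embeds), id_embeds]).to(dtype=torch.float16, device="cuda")

# CLIP embedding for the hidden image projection layers
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([face_image], None, torch.device("cuda"), 1, True)[0]
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False  # set True for the Plus v2 checkpoint

image = pipeline(
    prompt="a photo of a person in a garden",
    ip_adapter_image_embeds=[id_embeds],
    num_inference_steps=30,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
```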
@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")

IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable the PEFT backend in order to load the weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).

```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2

||||
@@ -26,7 +26,14 @@ from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.models.embeddings import MultiIPAdapterImageProjection
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -45,300 +52,6 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoRAIPAdapterAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
rank (`int`, defaults to 4):
|
||||
The dimension of the LoRA update matrices.
|
||||
network_alpha (`int`, *optional*):
|
||||
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
the weight scale of LoRA.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
|
||||
The context length of the image features.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim=None,
|
||||
rank=4,
|
||||
network_alpha=None,
|
||||
lora_scale=1.0,
|
||||
scale=1.0,
|
||||
num_tokens=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.rank = rank
|
||||
self.lora_scale = lora_scale
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.scale = scale
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LoRAIPAdapterAttnProcessor2_0(nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
rank (`int`, defaults to 4):
|
||||
The dimension of the LoRA update matrices.
|
||||
network_alpha (`int`, *optional*):
|
||||
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
the weight scale of LoRA.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
|
||||
The context length of the image features.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim=None,
|
||||
rank=4,
|
||||
network_alpha=None,
|
||||
lora_scale=1.0,
|
||||
scale=1.0,
|
||||
num_tokens=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.rank = rank
|
||||
self.lora_scale = lora_scale
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.scale = scale
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterFullImageProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
|
||||
super().__init__()
|
||||
@@ -615,17 +328,13 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
return image_projection
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict):
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
)
|
||||
|
||||
num_image_text_embeds = 4
|
||||
|
||||
self.unet.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
lora_dict = {}
|
||||
key_id = 0
|
||||
for name in self.unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
|
||||
@@ -642,94 +351,99 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class()
|
||||
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
|
||||
attn_module = self.unet
|
||||
for n in name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features,
|
||||
out_features=attn_module.to_q.out_features,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features,
|
||||
out_features=attn_module.to_k.out_features,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features,
|
||||
out_features=attn_module.to_v.out_features,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=rank,
|
||||
)
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for k, module in attn_module.named_children():
|
||||
index = "."
|
||||
if not hasattr(module, "set_lora_layer"):
|
||||
index = ".0."
|
||||
module = module[0]
|
||||
lora_layer = getattr(module, "lora_layer")
|
||||
for lora_name, w in lora_layer.state_dict().items():
|
||||
value_dict.update(
|
||||
{
|
||||
f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][
|
||||
f"{key_id}.{k}_lora.{lora_name}"
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
attn_module.load_state_dict(value_dict, strict=False)
|
||||
attn_module.to(dtype=self.dtype, device=self.device)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{
|
||||
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_out_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
|
||||
)
|
||||
key_id += 1
|
||||
else:
|
||||
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
|
||||
attn_processor_class = (
|
||||
LoRAIPAdapterAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else LoRAIPAdapterAttnProcessor
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
rank=rank,
|
||||
num_tokens=num_image_text_embeds,
|
||||
).to(dtype=self.dtype, device=self.device)
|
||||
|
||||
value_dict = {}
|
||||
for k, w in attn_procs[name].state_dict().items():
|
||||
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{
|
||||
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_out_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
|
||||
)
|
||||
lora_dict.update(
|
||||
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
key_id += 1
|
||||
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
self.load_lora_weights(lora_dict, adapter_name="faceid")
|
||||
self.set_adapters(["faceid"], adapter_weights=[1.0])
|
||||
|
||||
# convert IP-Adapter Image Projection layers to diffusers
|
||||
image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
|
||||
image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
|
||||
|
||||
self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
|
||||
self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.unet.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = scale
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = [scale]
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
@@ -1298,7 +1012,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
image_embeds = [image_embeds]
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
@@ -1319,7 +1033,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {}
|
||||
|
||||
# 6.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
|
||||
@@ -21,6 +21,7 @@ from safetensors import safe_open
|
||||
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
@@ -228,6 +229,18 @@ class IPAdapterMixin:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Sets the conditioning scale between text and image.
|
||||
|
||||
@@ -27,6 +27,8 @@ from torch import nn
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
IPAdapterFaceIDImageProjection,
|
||||
IPAdapterFaceIDPlusImageProjection,
|
||||
IPAdapterFullImageProjection,
|
||||
IPAdapterPlusImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
@@ -756,6 +758,90 @@ class UNet2DConditionLoadersMixin:
|
||||
diffusers_name = diffusers_name.replace("proj.3", "norm")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
elif "perceiver_resampler.proj_in.weight" in state_dict:
|
||||
# IP-Adapter Face ID Plus
|
||||
id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
|
||||
embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
|
||||
hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
|
||||
output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
|
||||
heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
with init_context():
|
||||
image_projection = IPAdapterFaceIDPlusImageProjection(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
heads=heads,
|
||||
id_embeddings_dim=id_embeddings_dim,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("perceiver_resampler.", "")
|
||||
diffusers_name = diffusers_name.replace("0.to", "attn.to")
|
||||
diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
|
||||
diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
|
||||
diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
|
||||
diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
|
||||
diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
|
||||
diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
|
||||
diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
|
||||
diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
|
||||
diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
|
||||
diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
|
||||
diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
|
||||
diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
|
||||
diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
|
||||
diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
|
||||
diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
|
||||
diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
|
||||
|
||||
if "norm1" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
|
||||
elif "norm2" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
|
||||
elif "to_kv" in diffusers_name:
|
||||
v_chunk = value.chunk(2, dim=0)
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
|
||||
elif "to_out" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
|
||||
elif "proj.0.weight" == diffusers_name:
|
||||
updated_state_dict["proj.net.0.proj.weight"] = value
|
||||
elif "proj.0.bias" == diffusers_name:
|
||||
updated_state_dict["proj.net.0.proj.bias"] = value
|
||||
elif "proj.2.weight" == diffusers_name:
|
||||
updated_state_dict["proj.net.2.weight"] = value
|
||||
elif "proj.2.bias" == diffusers_name:
|
||||
updated_state_dict["proj.net.2.bias"] = value
|
||||
else:
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
elif "norm.weight" in state_dict:
|
||||
# IP-Adapter Face ID
|
||||
id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
|
||||
id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
|
||||
multiplier = id_embeddings_dim_out // id_embeddings_dim_in
|
||||
norm_layer = "norm.weight"
|
||||
cross_attention_dim = state_dict[norm_layer].shape[0]
|
||||
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
|
||||
|
||||
with init_context():
|
||||
image_projection = IPAdapterFaceIDImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=id_embeddings_dim_in,
|
||||
mult=multiplier,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds = state_dict["latents"].shape[1]
|
||||
@@ -847,6 +933,7 @@ class UNet2DConditionLoadersMixin:
|
||||
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class()
|
||||
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
@@ -859,6 +946,12 @@ class UNet2DConditionLoadersMixin:
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Full Face
|
||||
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
|
||||
elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Face ID Plus
|
||||
num_image_text_embeds += [4]
|
||||
elif "norm.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Face ID
|
||||
num_image_text_embeds += [4]
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
|
||||
@@ -910,6 +1003,59 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
def _load_ip_adapter_loras(self, state_dicts):
|
||||
lora_dicts = {}
|
||||
for key_id, name in enumerate(self.attn_processors.keys()):
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
|
||||
if i not in lora_dicts:
|
||||
lora_dicts[i] = {}
|
||||
lora_dicts[i].update(
|
||||
{
|
||||
f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_k_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{
|
||||
f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_q_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{
|
||||
f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_v_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{
|
||||
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_out_lora.down.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
|
||||
)
|
||||
lora_dicts[i].update(
|
||||
{
|
||||
f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
|
||||
f"{key_id}.to_out_lora.up.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
return lora_dicts
|
||||
|
||||
|
||||
class FromOriginalUNetMixin:
|
||||
"""
|
||||
|
||||
@@ -472,6 +472,22 @@ class IPAdapterFullImageProjection(nn.Module):
|
||||
return self.norm(self.ff(image_embeds))
|
||||
|
||||
|
||||
class IPAdapterFaceIDImageProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.num_tokens = num_tokens
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
|
||||
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds: torch.FloatTensor):
|
||||
x = self.ff(image_embeds)
|
||||
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
||||
super().__init__()
|
||||
@@ -794,13 +810,14 @@ class IPAdapterPlusImageProjection(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
----
|
||||
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
||||
that is the same
|
||||
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
||||
hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
||||
hidden_dims (int):
|
||||
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
||||
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
||||
Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
|
||||
Defaults to 16. num_queries (int):
|
||||
The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
|
||||
of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
"""
|
||||
@@ -851,11 +868,8 @@ class IPAdapterPlusImageProjection(nn.Module):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
----
|
||||
x (torch.Tensor): Input Tensor.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
torch.Tensor: Output Tensor.
|
||||
"""
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
@@ -875,6 +889,119 @@ class IPAdapterPlusImageProjection(nn.Module):
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjectionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int = 768,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
ffn_ratio: float = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.ln0 = nn.LayerNorm(embed_dims)
|
||||
self.ln1 = nn.LayerNorm(embed_dims)
|
||||
self.attn = Attention(
|
||||
query_dim=embed_dims,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
out_bias=False,
|
||||
)
|
||||
self.ff = nn.Sequential(
|
||||
nn.LayerNorm(embed_dims),
|
||||
FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
|
||||
)
|
||||
|
||||
def forward(self, x, latents, residual):
|
||||
encoder_hidden_states = self.ln0(x)
|
||||
latents = self.ln1(latents)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
|
||||
latents = self.attn(latents, encoder_hidden_states) + residual
|
||||
latents = self.ff(latents) + latents
|
||||
return latents
|
||||
|
||||
|
||||
class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
||||
"""FacePerceiverResampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
|
||||
that is the same
|
||||
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
|
||||
hidden_dims (int):
|
||||
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
|
||||
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
|
||||
Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
ffproj_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels (for ID embeddings). Defaults to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int = 768,
|
||||
output_dims: int = 768,
|
||||
hidden_dims: int = 1280,
|
||||
id_embeddings_dim: int = 512,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_tokens: int = 4,
|
||||
num_queries: int = 8,
|
||||
ffn_ratio: float = 4,
|
||||
ffproj_ratio: int = 2,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.num_tokens = num_tokens
|
||||
self.embed_dim = embed_dims
|
||||
self.clip_embeds = None
|
||||
self.shortcut = False
|
||||
self.shortcut_scale = 1.0
|
||||
|
||||
self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
|
||||
self.norm = nn.LayerNorm(embed_dims)
|
||||
|
||||
self.proj_in = nn.Linear(hidden_dims, embed_dims)
|
||||
|
||||
self.proj_out = nn.Linear(embed_dims, output_dims)
|
||||
self.norm_out = nn.LayerNorm(output_dims)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
||||
)
|
||||
|
||||
def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
id_embeds (torch.Tensor): Input Tensor (ID embeds).
|
||||
Returns:
|
||||
torch.Tensor: Output Tensor.
|
||||
"""
|
||||
id_embeds = id_embeds.to(self.clip_embeds.dtype)
|
||||
id_embeds = self.proj(id_embeds)
|
||||
id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
|
||||
id_embeds = self.norm(id_embeds)
|
||||
latents = id_embeds
|
||||
|
||||
clip_embeds = self.proj_in(self.clip_embeds)
|
||||
x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
|
||||
|
||||
for block in self.layers:
|
||||
residual = latents
|
||||
latents = block(x, latents, residual)
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
out = self.norm_out(latents)
|
||||
if self.shortcut:
|
||||
out = id_embeds + self.shortcut_scale * out
|
||||
return out
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
super().__init__()
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_ip_adapter_faceid_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
# no LoRA weights
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
|
||||
)
|
||||
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 2
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = IPAdapterFaceIDImageProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = {}
|
||||
sd = image_projection.state_dict()
|
||||
ip_image_projection_state_dict.update(
|
||||
{
|
||||
"proj.0.weight": sd["ff.net.0.proj.weight"],
|
||||
"proj.0.bias": sd["ff.net.0.proj.bias"],
|
||||
"proj.2.weight": sd["ff.net.2.weight"],
|
||||
"proj.2.bias": sd["ff.net.2.bias"],
|
||||
"norm.weight": sd["norm.weight"],
|
||||
"norm.bias": sd["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
train_kv = True
|
||||
train_q_out = True
|
||||
|
||||
@@ -37,6 +37,7 @@ from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_flaky,
|
||||
load_pt,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
@@ -306,6 +307,35 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_text_to_image_face_id(self):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter-FaceID",
|
||||
subfolder=None,
|
||||
weight_name="ip-adapter-faceid_sd15.bin",
|
||||
image_encoder_folder=None,
|
||||
)
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
|
||||
0
|
||||
]
|
||||
id_embeds = id_embeds.reshape((2, 1, 1, 512))
|
||||
inputs["ip_adapter_image_embeds"] = [id_embeds]
|
||||
inputs["ip_adapter_image"] = None
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953]
|
||||
)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -48,7 +48,10 @@ from ..models.autoencoders.test_models_vae import (
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
|
||||
from ..models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_faceid_state_dict,
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -239,6 +242,9 @@ class IPAdapterTesterMixin:
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((2, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
@@ -416,6 +422,46 @@ class IPAdapterTesterMixin:
|
||||
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
output_without_adapter = pipe(**inputs)[0]
|
||||
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
|
||||
|
||||
adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs)[0]
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs)[0]
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
self.assertLess(
|
||||
max_diff_without_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class PipelineLatentTesterMixin:
|
||||
"""
|
||||
|
||||