
Move IP Adapter Face ID to core (#7186)

* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Fabio Rigano
Date: 2024-04-19 02:13:27 +02:00
Committed by: GitHub
Parent: e23c27e905
Commit: b5c8b555d7
10 changed files with 595 additions and 378 deletions

@@ -30,7 +30,7 @@ from diffusers.models.attention_processor import (
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
-from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
+from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
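The newly imported IPAdapterFaceIDImageProjection backs the state-dict helper added below. A hedged sketch of what it computes, assuming it maps a flat (batch, image_embed_dim) embedding to num_tokens context tokens (parameter values mirror the helper below; this is not code from the commit):

import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

# Illustrative instantiation matching the helper's arguments below.
proj = IPAdapterFaceIDImageProjection(
    cross_attention_dim=32, image_embed_dim=32, mult=2, num_tokens=4
)
tokens = proj(torch.randn(2, 32))
print(tokens.shape)  # expected (2, 4, 32): num_tokens tokens per image embedding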
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
     return ip_state_dict
 
 
+def create_ip_adapter_faceid_state_dict(model):
+    # "ip_adapter" (cross-attention weights)
+    # no LoRA weights
+    ip_cross_attn_state_dict = {}
+    key_id = 1
+
+    for name in model.attn_processors.keys():
+        cross_attention_dim = (
+            None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
+        )
+
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        if cross_attention_dim is not None:
+            sd = IPAdapterAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+            ).state_dict()
+            ip_cross_attn_state_dict.update(
+                {
+                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                }
+            )
+
+            key_id += 2
+
+    # "image_proj" (ImageProjection layer weights)
+    cross_attention_dim = model.config["cross_attention_dim"]
+    image_projection = IPAdapterFaceIDImageProjection(
+        cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
+    )
+
+    ip_image_projection_state_dict = {}
+    sd = image_projection.state_dict()
+    ip_image_projection_state_dict.update(
+        {
+            "proj.0.weight": sd["ff.net.0.proj.weight"],
+            "proj.0.bias": sd["ff.net.0.proj.bias"],
+            "proj.2.weight": sd["ff.net.2.weight"],
+            "proj.2.bias": sd["ff.net.2.bias"],
+            "norm.weight": sd["norm.weight"],
+            "norm.bias": sd["norm.bias"],
+        }
+    )
+
+    del sd
+    ip_state_dict = {}
+    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+
+    return ip_state_dict
+
+
 def create_custom_diffusion_layers(model, mock_weights: bool = True):
     train_kv = True
     train_q_out = True

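For orientation, a minimal usage sketch of the helper above (the small dummy UNet config is illustrative, not part of this commit). Only cross-attention (attn2) processors receive IP-Adapter weights, which is why key_id advances by two and the keys come out as 1, 3, 5, ...:

# Hypothetical exercise of create_ip_adapter_faceid_state_dict.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
state_dict = create_ip_adapter_faceid_state_dict(unet)
# Cross-attention weights sit under odd integer keys, the MLP under "image_proj".
print(sorted(state_dict["ip_adapter"]))  # e.g. ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight', ...]
print(sorted(state_dict["image_proj"]))  # ['norm.bias', 'norm.weight', 'proj.0.bias', 'proj.0.weight', ...]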
@@ -37,6 +37,7 @@ from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     is_flaky,
+    load_pt,
     numpy_cosine_similarity_distance,
     require_torch_gpu,
     slow,
@@ -306,6 +307,35 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
 
+    def test_text_to_image_face_id(self):
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter-FaceID",
+            subfolder=None,
+            weight_name="ip-adapter-faceid_sd15.bin",
+            image_encoder_folder=None,
+        )
+        pipeline.set_ip_adapter_scale(0.7)
+
+        inputs = self.get_dummy_inputs()
+        id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
+            0
+        ]
+        id_embeds = id_embeds.reshape((2, 1, 1, 512))
+        inputs["ip_adapter_image_embeds"] = [id_embeds]
+        inputs["ip_adapter_image"] = None
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array(
+            [0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4
+
 
 @slow
 @require_torch_gpu

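Outside the test suite, the same flow looks roughly as follows. This is a sketch based on the integration test above: the random id_embeds tensor stands in for real precomputed face-ID embeddings (e.g. from InsightFace), whose extraction is not shown, and the prompt is illustrative:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",
    image_encoder_folder=None,  # Face ID ships no CLIP image encoder
)
pipe.set_ip_adapter_scale(0.7)

# Placeholder for precomputed face embeddings: a negative/positive pair,
# shaped (2, 1, 1, 512) as in the integration test above.
id_embeds = torch.randn(2, 1, 1, 512, dtype=torch.float16, device="cuda")
image = pipe("a photo of a person", ip_adapter_image_embeds=[id_embeds]).images[0]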
@@ -48,7 +48,10 @@ from ..models.autoencoders.test_models_vae import (
     get_autoencoder_tiny_config,
     get_consistency_vae_config,
 )
-from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
+from ..models.unets.test_models_unet_2d_condition import (
+    create_ip_adapter_faceid_state_dict,
+    create_ip_adapter_state_dict,
+)
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -239,6 +242,9 @@ class IPAdapterTesterMixin:
     def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
         return torch.randn((2, 1, cross_attention_dim), device=torch_device)
 
+    def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
+
     def _get_dummy_masks(self, input_size: int = 64):
         _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
         _masks[0, :, :, : int(input_size / 2)] = 1
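The new helper differs from _get_dummy_image_embeds only by one extra singleton axis, and the leading 2 packs the negative/positive pair used with classifier-free guidance, matching the (2, 1, 1, 512) reshape in the integration test above. A minimal sketch of the contrast:

import torch

plain = torch.randn((2, 1, 32))       # shape returned by _get_dummy_image_embeds
faceid = torch.randn((2, 1, 1, 32))   # shape returned by _get_dummy_faceid_image_embeds
assert faceid.squeeze(2).shape == plain.shape  # only an extra singleton axis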
@@ -416,6 +422,46 @@ class IPAdapterTesterMixin:
             max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
         )
 
+    def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+        adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+        )
+
 
 class PipelineLatentTesterMixin:
     """