mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Move IP Adapter Face ID to core (#7186)
* Switch to peft and multi proj layers * Move Face ID loading and inference to core --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -30,7 +30,7 @@ from diffusers.models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_ip_adapter_faceid_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
# no LoRA weights
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
|
||||
)
|
||||
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 2
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = IPAdapterFaceIDImageProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = {}
|
||||
sd = image_projection.state_dict()
|
||||
ip_image_projection_state_dict.update(
|
||||
{
|
||||
"proj.0.weight": sd["ff.net.0.proj.weight"],
|
||||
"proj.0.bias": sd["ff.net.0.proj.bias"],
|
||||
"proj.2.weight": sd["ff.net.2.weight"],
|
||||
"proj.2.bias": sd["ff.net.2.bias"],
|
||||
"norm.weight": sd["norm.weight"],
|
||||
"norm.bias": sd["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
train_kv = True
|
||||
train_q_out = True
|
||||
|
||||
@@ -37,6 +37,7 @@ from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_flaky,
|
||||
load_pt,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
@@ -306,6 +307,35 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_text_to_image_face_id(self):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter-FaceID",
|
||||
subfolder=None,
|
||||
weight_name="ip-adapter-faceid_sd15.bin",
|
||||
image_encoder_folder=None,
|
||||
)
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
|
||||
0
|
||||
]
|
||||
id_embeds = id_embeds.reshape((2, 1, 1, 512))
|
||||
inputs["ip_adapter_image_embeds"] = [id_embeds]
|
||||
inputs["ip_adapter_image"] = None
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953]
|
||||
)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -48,7 +48,10 @@ from ..models.autoencoders.test_models_vae import (
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
|
||||
from ..models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_faceid_state_dict,
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -239,6 +242,9 @@ class IPAdapterTesterMixin:
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((2, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
@@ -416,6 +422,46 @@ class IPAdapterTesterMixin:
|
||||
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
output_without_adapter = pipe(**inputs)[0]
|
||||
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
|
||||
|
||||
adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs)[0]
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs)[0]
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
self.assertLess(
|
||||
max_diff_without_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class PipelineLatentTesterMixin:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user