mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Multi IP-Adapter for Flux pipelines (#10867)
* Initial implementation of Flux multi IP-Adapter
* Update src/diffusers/pipelines/flux/pipeline_flux.py
  Co-authored-by: hlky <hlky@hlky.ac>
* Update src/diffusers/pipelines/flux/pipeline_flux.py
  Co-authored-by: hlky <hlky@hlky.ac>
* Changes for ipa image embeds
* Update src/diffusers/pipelines/flux/pipeline_flux.py
  Co-authored-by: hlky <hlky@hlky.ac>
* Update src/diffusers/pipelines/flux/pipeline_flux.py
  Co-authored-by: hlky <hlky@hlky.ac>
* make style && make quality
* Updated ip_adapter test
* Created typing_utils.py

---------

Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
@@ -23,7 +23,9 @@ from safetensors import safe_open
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_detailed_type,
|
||||
_get_model_file,
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
@@ -577,29 +579,36 @@ class FluxIPAdapterMixin:
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
transformer = self.transformer
|
||||
if not isinstance(scale, list):
|
||||
scale = [[scale] * transformer.config.num_layers]
|
||||
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
|
||||
if len(scale) != transformer.config.num_layers:
|
||||
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
|
||||
|
||||
scale_type = Union[int, float]
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
scale_configs = scale
|
||||
if len(scale) != num_ip_adapters:
|
||||
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||
|
||||
key_id = 0
|
||||
for attn_name, attn_processor in transformer.attn_processors.items():
|
||||
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
attn_processor.scale[i] = scale_config[key_id]
|
||||
key_id += 1
|
||||
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||
raise ValueError(
|
||||
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||
)
|
||||
|
||||
# Scalars are transformed to lists with length num_layers
|
||||
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||
|
||||
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
|
||||
@@ -2780,9 +2780,8 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = None
|
||||
# for ip-adapter
|
||||
# TODO: support for multiple adapters
|
||||
ip_attn_output = torch.zeros_like(hidden_states)
|
||||
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
@@ -2793,12 +2792,14 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_attn_output = F.scaled_dot_product_attention(
|
||||
current_ip_hidden_states = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_attn_output = scale * ip_attn_output
|
||||
ip_attn_output = ip_attn_output.to(ip_query.dtype)
|
||||
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
ip_attn_output += scale * current_ip_hidden_states
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
|
||||
@@ -2583,6 +2583,11 @@ class MultiIPAdapterImageProjection(nn.Module):
|
||||
super().__init__()
|
||||
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
||||
|
||||
@property
def num_ip_adapters(self) -> int:
    """Return how many entries `self.image_projection_layers` holds (one per loaded IP-Adapter)."""
    projection_layers = self.image_projection_layers
    return len(projection_layers)
|
||||
|
||||
def forward(self, image_embeds: List[torch.Tensor]):
|
||||
projected_image_embeds = []
|
||||
|
||||
|
||||
@@ -405,23 +405,28 @@ class FluxPipeline(
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for i, single_image_embeds in enumerate(image_embeds):
|
||||
for single_image_embeds in image_embeds:
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
@@ -872,10 +877,13 @@ class FluxPipeline(
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
|
||||
break
|
||||
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
|
||||
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
|
||||
|
||||
|
||||
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
|
||||
"""
|
||||
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
|
||||
the correct type as well.
|
||||
"""
|
||||
if not isinstance(class_or_tuple, tuple):
|
||||
class_or_tuple = (class_or_tuple,)
|
||||
|
||||
# Unpack unions
|
||||
unpacked_class_or_tuple = []
|
||||
for t in class_or_tuple:
|
||||
if get_origin(t) is Union:
|
||||
unpacked_class_or_tuple.extend(get_args(t))
|
||||
else:
|
||||
unpacked_class_or_tuple.append(t)
|
||||
class_or_tuple = tuple(unpacked_class_or_tuple)
|
||||
|
||||
if Any in class_or_tuple:
|
||||
return True
|
||||
|
||||
obj_type = type(obj)
|
||||
# Classes with obj's type
|
||||
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
|
||||
|
||||
# Singular types (e.g. int, ControlNet, ...)
|
||||
# Untyped collections (e.g. List, but not List[int])
|
||||
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
|
||||
if () in elem_class_or_tuple:
|
||||
return True
|
||||
# Typed lists or sets
|
||||
elif obj_type in (list, set):
|
||||
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
|
||||
# Typed tuples
|
||||
elif obj_type is tuple:
|
||||
return any(
|
||||
# Tuples with any length and single type (e.g. Tuple[int, ...])
|
||||
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
|
||||
or
|
||||
# Tuples with fixed length and any types (e.g. Tuple[int, str])
|
||||
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
|
||||
for t in elem_class_or_tuple
|
||||
)
|
||||
# Typed dicts
|
||||
elif obj_type is dict:
|
||||
return any(
|
||||
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
|
||||
for kt, vt in elem_class_or_tuple
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _get_detailed_type(obj: Any) -> Type:
|
||||
"""
|
||||
Gets a detailed type for an object, including nested types for collections.
|
||||
"""
|
||||
obj_type = type(obj)
|
||||
|
||||
if obj_type in (list, set):
|
||||
obj_origin_type = List if obj_type is list else Set
|
||||
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
|
||||
return obj_origin_type[elems_type]
|
||||
elif obj_type is tuple:
|
||||
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
|
||||
elif obj_type is dict:
|
||||
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
|
||||
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
|
||||
return Dict[keys_type, values_type]
|
||||
else:
|
||||
return obj_type
|
||||
|
||||
@@ -54,6 +54,8 @@ from ..utils import (
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
BaseOutput,
|
||||
PushToHubMixin,
|
||||
_get_detailed_type,
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_torch_npu_available,
|
||||
@@ -78,12 +80,10 @@ from .pipeline_loading_utils import (
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_components_and_folders,
|
||||
_get_custom_pipeline_class,
|
||||
_get_detailed_type,
|
||||
_get_final_device_map,
|
||||
_get_ignore_patterns,
|
||||
_get_pipeline_class,
|
||||
_identify_model_variants,
|
||||
_is_valid_type,
|
||||
_maybe_raise_error_for_incorrect_transformers,
|
||||
_maybe_raise_warning_for_inpainting,
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
|
||||
@@ -123,6 +123,7 @@ from .state_dict_utils import (
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
)
|
||||
from .typing_utils import _get_detailed_type, _is_valid_type
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
91
src/diffusers/utils/typing_utils.py
Normal file
91
src/diffusers/utils/typing_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Typing utilities: Utilities related to type checking and validation
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin
|
||||
|
||||
|
||||
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
|
||||
"""
|
||||
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
|
||||
the correct type as well.
|
||||
"""
|
||||
if not isinstance(class_or_tuple, tuple):
|
||||
class_or_tuple = (class_or_tuple,)
|
||||
|
||||
# Unpack unions
|
||||
unpacked_class_or_tuple = []
|
||||
for t in class_or_tuple:
|
||||
if get_origin(t) is Union:
|
||||
unpacked_class_or_tuple.extend(get_args(t))
|
||||
else:
|
||||
unpacked_class_or_tuple.append(t)
|
||||
class_or_tuple = tuple(unpacked_class_or_tuple)
|
||||
|
||||
if Any in class_or_tuple:
|
||||
return True
|
||||
|
||||
obj_type = type(obj)
|
||||
# Classes with obj's type
|
||||
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
|
||||
|
||||
# Singular types (e.g. int, ControlNet, ...)
|
||||
# Untyped collections (e.g. List, but not List[int])
|
||||
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
|
||||
if () in elem_class_or_tuple:
|
||||
return True
|
||||
# Typed lists or sets
|
||||
elif obj_type in (list, set):
|
||||
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
|
||||
# Typed tuples
|
||||
elif obj_type is tuple:
|
||||
return any(
|
||||
# Tuples with any length and single type (e.g. Tuple[int, ...])
|
||||
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
|
||||
or
|
||||
# Tuples with fixed length and any types (e.g. Tuple[int, str])
|
||||
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
|
||||
for t in elem_class_or_tuple
|
||||
)
|
||||
# Typed dicts
|
||||
elif obj_type is dict:
|
||||
return any(
|
||||
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
|
||||
for kt, vt in elem_class_or_tuple
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _get_detailed_type(obj: Any) -> Type:
|
||||
"""
|
||||
Gets a detailed type for an object, including nested types for collections.
|
||||
"""
|
||||
obj_type = type(obj)
|
||||
|
||||
if obj_type in (list, set):
|
||||
obj_origin_type = List if obj_type is list else Set
|
||||
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
|
||||
return obj_origin_type[elems_type]
|
||||
elif obj_type is tuple:
|
||||
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
|
||||
elif obj_type is dict:
|
||||
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
|
||||
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
|
||||
return Dict[keys_type, values_type]
|
||||
else:
|
||||
return obj_type
|
||||
@@ -527,7 +527,9 @@ class FluxIPAdapterTesterMixin:
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
@@ -545,6 +547,7 @@ class FluxIPAdapterTesterMixin:
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
|
||||
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
@@ -578,6 +581,44 @@ class FluxIPAdapterTesterMixin:
|
||||
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
)
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_flux_ip_adapter_state_dict(pipe.transformer)
|
||||
adapter_state_dict_2 = create_flux_ip_adapter_state_dict(pipe.transformer)
|
||||
pipe.transformer._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
|
||||
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs)[0]
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
|
||||
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs)[0]
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
self.assertLess(
|
||||
max_diff_without_multi_adapter_scale,
|
||||
expected_max_diff,
|
||||
"Output without multi-ip-adapter must be same as normal inference",
|
||||
)
|
||||
self.assertGreater(
|
||||
max_diff_with_multi_adapter_scale,
|
||||
1e-2,
|
||||
"Output with multi-ip-adapter scale must be different from normal inference",
|
||||
)
|
||||
|
||||
|
||||
class PipelineLatentTesterMixin:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user