From 1450c2ac4f384bbca65d6b7a132fa876b511b4e4 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Tue, 25 Feb 2025 09:51:15 +0000 Subject: [PATCH] Multi IP-Adapter for Flux pipelines (#10867) * Initial implementation of Flux multi IP-Adapter * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky * Changes for ipa image embeds * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky * Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky * make style && make quality * Updated ip_adapter test * Created typing_utils.py --------- Co-authored-by: hlky --- src/diffusers/loaders/ip_adapter.py | 49 ++++++---- src/diffusers/models/attention_processor.py | 15 +-- src/diffusers/models/embeddings.py | 5 + src/diffusers/pipelines/flux/pipeline_flux.py | 22 +++-- .../pipelines/pipeline_loading_utils.py | 75 +-------------- src/diffusers/pipelines/pipeline_utils.py | 4 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/typing_utils.py | 91 +++++++++++++++++++ tests/pipelines/test_pipelines_common.py | 41 +++++++++ 9 files changed, 193 insertions(+), 110 deletions(-) create mode 100644 src/diffusers/utils/typing_utils.py diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 7b691d1fe1..33144090cb 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -23,7 +23,9 @@ from safetensors import safe_open from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict from ..utils import ( USE_PEFT_BACKEND, + _get_detailed_type, _get_model_file, + _is_valid_type, is_accelerate_available, is_torch_version, is_transformers_available, @@ -577,29 +579,36 @@ class FluxIPAdapterMixin: pipeline.set_ip_adapter_scale(ip_strengths) ``` """ - transformer = self.transformer - if not isinstance(scale, list): - scale = [[scale] * transformer.config.num_layers] - elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float): - if len(scale) != transformer.config.num_layers: - raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.") + + scale_type = Union[int, float] + num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters + num_layers = self.transformer.config.num_layers + + # Single value for all layers of all IP-Adapters + if isinstance(scale, scale_type): + scale = [scale for _ in range(num_ip_adapters)] + # List of per-layer scales for a single IP-Adapter + elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1: scale = [scale] + # Invalid scale type + elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]): + raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.") - scale_configs = scale + if len(scale) != num_ip_adapters: + raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.") - key_id = 0 - for attn_name, attn_processor in transformer.attn_processors.items(): - if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)): - if len(scale_configs) != len(attn_processor.scale): - raise ValueError( - f"Cannot assign {len(scale_configs)} scale_configs to " - f"{len(attn_processor.scale)} IP-Adapter." - ) - elif len(scale_configs) == 1: - scale_configs = scale_configs * len(attn_processor.scale) - for i, scale_config in enumerate(scale_configs): - attn_processor.scale[i] = scale_config[key_id] - key_id += 1 + if any(len(s) != num_layers for s in scale if isinstance(s, list)): + invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers} + raise ValueError( + f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}." + ) + + # Scalars are transformed to lists with length num_layers + scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale] + + # Set scales. zip over scale_configs prevents going into single transformer layers + for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs): + attn_processor.scale = scale def unload_ip_adapter(self): """ diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c9f5e7c115..fe126c46df 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2780,9 +2780,8 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): # IP-adapter ip_query = hidden_states_query_proj - ip_attn_output = None - # for ip-adapter - # TODO: support for multiple adapters + ip_attn_output = torch.zeros_like(hidden_states) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip ): @@ -2793,12 +2792,14 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - ip_attn_output = F.scaled_dot_product_attention( + current_ip_hidden_states = F.scaled_dot_product_attention( ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_attn_output = scale * ip_attn_output - ip_attn_output = ip_attn_output.to(ip_query.dtype) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states return hidden_states, encoder_hidden_states, ip_attn_output else: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 390b752abe..04a0b273f1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2583,6 +2583,11 @@ class MultiIPAdapterImageProjection(nn.Module): super().__init__() self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + @property + def num_ip_adapters(self) -> int: + """Number of IP-Adapters loaded.""" + return len(self.image_projection_layers) + def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds = [] diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 9f4788a498..e49371c0d5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -405,23 +405,28 @@ class FluxPipeline( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): + for single_ip_adapter_image in ip_adapter_image: single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - image_embeds.append(single_image_embeds[None, :]) else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + for single_image_embeds in ip_adapter_image_embeds: image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): + for single_image_embeds in image_embeds: single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) @@ -872,10 +877,13 @@ class FluxPipeline( negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None ): negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 0e2cbb32d3..9a9afa198b 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -17,7 +17,7 @@ import os import re import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Union import requests import torch @@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict): break if has_transformers_component and not is_transformers_version(">", "4.47.1"): raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") - - -def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: - """ - Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of - the correct type as well. - """ - if not isinstance(class_or_tuple, tuple): - class_or_tuple = (class_or_tuple,) - - # Unpack unions - unpacked_class_or_tuple = [] - for t in class_or_tuple: - if get_origin(t) is Union: - unpacked_class_or_tuple.extend(get_args(t)) - else: - unpacked_class_or_tuple.append(t) - class_or_tuple = tuple(unpacked_class_or_tuple) - - if Any in class_or_tuple: - return True - - obj_type = type(obj) - # Classes with obj's type - class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} - - # Singular types (e.g. int, ControlNet, ...) - # Untyped collections (e.g. List, but not List[int]) - elem_class_or_tuple = {get_args(t) for t in class_or_tuple} - if () in elem_class_or_tuple: - return True - # Typed lists or sets - elif obj_type in (list, set): - return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) - # Typed tuples - elif obj_type is tuple: - return any( - # Tuples with any length and single type (e.g. Tuple[int, ...]) - (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) - or - # Tuples with fixed length and any types (e.g. Tuple[int, str]) - (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) - for t in elem_class_or_tuple - ) - # Typed dicts - elif obj_type is dict: - return any( - all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) - for kt, vt in elem_class_or_tuple - ) - - else: - return False - - -def _get_detailed_type(obj: Any) -> Type: - """ - Gets a detailed type for an object, including nested types for collections. - """ - obj_type = type(obj) - - if obj_type in (list, set): - obj_origin_type = List if obj_type is list else Set - elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] - return obj_origin_type[elems_type] - elif obj_type is tuple: - return Tuple[tuple(_get_detailed_type(x) for x in obj)] - elif obj_type is dict: - keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] - values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] - return Dict[keys_type, values_type] - else: - return obj_type diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e112947c8d..1b306b1805 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -54,6 +54,8 @@ from ..utils import ( DEPRECATED_REVISION_ARGS, BaseOutput, PushToHubMixin, + _get_detailed_type, + _is_valid_type, is_accelerate_available, is_accelerate_version, is_torch_npu_available, @@ -78,12 +80,10 @@ from .pipeline_loading_utils import ( _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, - _get_detailed_type, _get_final_device_map, _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, - _is_valid_type, _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d82aded4c4..08b1713d0e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -123,6 +123,7 @@ from .state_dict_utils import ( convert_state_dict_to_peft, convert_unet_state_dict_to_peft, ) +from .typing_utils import _get_detailed_type, _is_valid_type logger = get_logger(__name__) diff --git a/src/diffusers/utils/typing_utils.py b/src/diffusers/utils/typing_utils.py new file mode 100644 index 0000000000..2b5b1a4f5a --- /dev/null +++ b/src/diffusers/utils/typing_utils.py @@ -0,0 +1,91 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Typing utilities: Utilities related to type checking and validation +""" + +from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin + + +def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + """ + Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of + the correct type as well. + """ + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Classes with obj's type + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} + + # Singular types (e.g. int, ControlNet, ...) + # Untyped collections (e.g. List, but not List[int]) + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. Tuple[int, str]) + (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + +def _get_detailed_type(obj: Any) -> Type: + """ + Gets a detailed type for an object, including nested types for collections. + """ + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(_get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] + return Dict[keys_type, values_type] + else: + return obj_type diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 33a7fd9f2b..a98de5c9ea 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -527,7 +527,9 @@ class FluxIPAdapterTesterMixin: The following scenarios are tested: - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter. - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. """ # Raising the tolerance for this test when it's run on a CPU because we # compare against static slices and that can be shaky (with a VVVV low probability). @@ -545,6 +547,7 @@ class FluxIPAdapterTesterMixin: else: output_without_adapter = expected_pipe_slice + # 1. Single IP-Adapter test cases adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer) pipe.transformer._load_ip_adapter_weights(adapter_state_dict) @@ -578,6 +581,44 @@ class FluxIPAdapterTesterMixin: max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" ) + # 2. Multi IP-Adapter test cases + adapter_state_dict_1 = create_flux_ip_adapter_state_dict(pipe.transformer) + adapter_state_dict_2 = create_flux_ip_adapter_state_dict(pipe.transformer) + pipe.transformer._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) + + # forward pass with multi ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + pipe.set_ip_adapter_scale([0.0, 0.0]) + output_without_multi_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with multi ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + pipe.set_ip_adapter_scale([42.0, 42.0]) + output_with_multi_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_multi_adapter_scale = np.abs( + output_without_multi_adapter_scale - output_without_adapter + ).max() + max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() + self.assertLess( + max_diff_without_multi_adapter_scale, + expected_max_diff, + "Output without multi-ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_multi_adapter_scale, + 1e-2, + "Output with multi-ip-adapter scale must be different from normal inference", + ) + class PipelineLatentTesterMixin: """