1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix ip adapter type checking.

This commit is contained in:
sayakpaul
2026-01-13 09:52:22 +05:30
parent 4cbe1aad54
commit b30be7d90f

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from pathlib import Path
from typing import List, Union
import torch
import torch.nn.functional as F
@@ -822,18 +823,18 @@ class FluxIPAdapterMixin:
```
"""
scale_type = int | float
scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers
# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# list of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, list[scale_type]) and num_ip_adapters == 1:
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, list[scale_type | list[scale_type]]):
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
if len(scale) != num_ip_adapters: