diff --git a/src/diffusers/utils/typing_utils.py b/src/diffusers/utils/typing_utils.py index 00853608b4..71f2e38b6a 100644 --- a/src/diffusers/utils/typing_utils.py +++ b/src/diffusers/utils/typing_utils.py @@ -18,6 +18,12 @@ Typing utilities: Utilities related to type checking and validation from typing import Any, Set, Type, Union, get_args, get_origin +try: + from types import UnionType as _UnionType +except ImportError: # Python < 3.10 + _UnionType = None + + def _is_valid_type(obj: Any, class_or_tuple: Type | tuple[Type, ...]) -> bool: """ Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of @@ -29,7 +35,12 @@ def _is_valid_type(obj: Any, class_or_tuple: Type | tuple[Type, ...]) -> bool: # Unpack unions unpacked_class_or_tuple = [] for t in class_or_tuple: - if get_origin(t) is Union: + origin = get_origin(t) + is_union = origin is Union or (_UnionType is not None and origin is _UnionType) + # For PEP 604 unions (e.g. int | float), origin can be None but the object itself is a UnionType + if not is_union and _UnionType is not None and isinstance(t, _UnionType): + is_union = True + if is_union: unpacked_class_or_tuple.extend(get_args(t)) else: unpacked_class_or_tuple.append(t)