1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2025-10-27 20:43:30 +05:30
parent 19fe63170c
commit 3a00e23f5a

View File

@@ -22,7 +22,7 @@ import sys
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional, Union, get_args, get_origin
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
import httpx
import numpy as np
@@ -1815,15 +1815,78 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
@classmethod
def _get_signature_types(cls):
signature_types = {}
for k, v in inspect.signature(cls.__init__).parameters.items():
if inspect.isclass(v.annotation):
signature_types[k] = (v.annotation,)
elif get_origin(v.annotation) in [Union, types.UnionType]:
signature_types[k] = get_args(v.annotation)
elif get_origin(v.annotation) in [list, dict]:
signature_types[k] = (v.annotation,)
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {}
localns = dict(vars(cls))
try:
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True)
except TypeError:
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns)
except Exception as exc:
logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc)
type_hints = {}
def _is_union(annotation: Any) -> bool:
origin = get_origin(annotation)
union_type = getattr(types, "UnionType", None)
if origin in (Union, union_type):
return True
return union_type is not None and isinstance(annotation, union_type)
def _normalize_annotation(annotation: Any) -> tuple[type, ...]:
if annotation in (inspect._empty, None) or annotation is Any:
return ()
if inspect.isclass(annotation):
return (annotation,)
if _is_union(annotation):
collected: list[type] = []
for arg in get_args(annotation):
collected.extend(_normalize_annotation(arg))
# preserve order while removing duplicates
unique: list[type] = []
seen: set[type] = set()
for item in collected:
if item not in seen:
seen.add(item)
unique.append(item)
return tuple(unique)
origin = get_origin(annotation)
if origin is not None:
if getattr(origin, "__qualname__", "") == "Annotated":
args = get_args(annotation)
return _normalize_annotation(args[0]) if args else ()
if getattr(origin, "__qualname__", "") == "Literal":
return ()
if inspect.isclass(origin):
return (origin,)
return ()
for name, parameter in inspect.signature(cls.__init__).parameters.items():
if name == "self":
continue
annotation = type_hints.get(name, parameter.annotation)
if isinstance(annotation, str):
try:
annotation = eval(annotation, module_globals, localns) # noqa: S307
except Exception as exc: # noqa: BLE001
logger.debug(
"Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc
)
annotation = inspect._empty
normalized = _normalize_annotation(annotation)
if normalized:
signature_types[name] = normalized
elif annotation not in (inspect._empty, None, Any):
logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.")
return signature_types
@property