1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2026-01-12 20:01:13 +05:30
parent 78233be7b4
commit 3a0efa38f5

View File

@@ -19,10 +19,9 @@ import inspect
import os
import re
import sys
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, get_args, get_origin, get_type_hints
from typing import Any, Callable, Dict, List, Union, get_args, get_origin
import httpx
import numpy as np
@@ -1820,84 +1819,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
@classmethod
def _get_signature_types(cls):
signature_types = {}
module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {}
localns = dict(vars(cls))
try:
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True)
except TypeError:
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns)
except Exception as exc:
logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc)
type_hints = {}
def _is_union(annotation: Any) -> bool:
origin = get_origin(annotation)
union_type = getattr(types, "UnionType", None)
if origin in (union_type):
return True
return union_type is not None and isinstance(annotation, union_type)
def _normalize_annotation(annotation: Any) -> tuple[type, ...]:
if annotation is inspect._empty:
return (inspect.Signature.empty,)
if annotation is None:
return (type(None),)
if annotation is Any:
return (Any,)
if inspect.isclass(annotation):
return (annotation,)
if _is_union(annotation):
collected: list[type] = []
for arg in get_args(annotation):
collected.extend(_normalize_annotation(arg))
# preserve order while removing duplicates
unique: list[type] = []
seen: set[type] = set()
for item in collected:
if item not in seen:
seen.add(item)
unique.append(item)
return tuple(unique)
origin = get_origin(annotation)
if origin is not None:
if getattr(origin, "__qualname__", "") == "Annotated":
args = get_args(annotation)
return _normalize_annotation(args[0]) if args else ()
if getattr(origin, "__qualname__", "") == "Literal":
return ()
if inspect.isclass(origin):
return (origin,)
return ()
for name, parameter in inspect.signature(cls.__init__).parameters.items():
if name == "self":
continue
annotation = type_hints.get(name, parameter.annotation)
if isinstance(annotation, str):
try:
annotation = eval(annotation, module_globals, localns) # noqa: S307
except Exception as exc: # noqa: BLE001
logger.debug(
"Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc
)
annotation = inspect._empty
normalized = _normalize_annotation(annotation)
if normalized:
signature_types[name] = normalized
elif annotation not in (inspect._empty, None, Any):
logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.")
for k, v in inspect.signature(cls.__init__).parameters.items():
if inspect.isclass(v.annotation):
signature_types[k] = (v.annotation,)
elif get_origin(v.annotation) == Union:
signature_types[k] = get_args(v.annotation)
elif get_origin(v.annotation) in [List, Dict, list, dict]:
signature_types[k] = (v.annotation,)
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
return signature_types
@property