mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -19,10 +19,9 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Dict, List, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
@@ -1820,84 +1819,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
@classmethod
|
||||
def _get_signature_types(cls):
|
||||
signature_types = {}
|
||||
module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {}
|
||||
localns = dict(vars(cls))
|
||||
|
||||
try:
|
||||
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True)
|
||||
except TypeError:
|
||||
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc)
|
||||
type_hints = {}
|
||||
|
||||
def _is_union(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
union_type = getattr(types, "UnionType", None)
|
||||
if origin in (union_type):
|
||||
return True
|
||||
return union_type is not None and isinstance(annotation, union_type)
|
||||
|
||||
def _normalize_annotation(annotation: Any) -> tuple[type, ...]:
|
||||
if annotation is inspect._empty:
|
||||
return (inspect.Signature.empty,)
|
||||
|
||||
if annotation is None:
|
||||
return (type(None),)
|
||||
|
||||
if annotation is Any:
|
||||
return (Any,)
|
||||
|
||||
if inspect.isclass(annotation):
|
||||
return (annotation,)
|
||||
|
||||
if _is_union(annotation):
|
||||
collected: list[type] = []
|
||||
for arg in get_args(annotation):
|
||||
collected.extend(_normalize_annotation(arg))
|
||||
# preserve order while removing duplicates
|
||||
unique: list[type] = []
|
||||
seen: set[type] = set()
|
||||
for item in collected:
|
||||
if item not in seen:
|
||||
seen.add(item)
|
||||
unique.append(item)
|
||||
return tuple(unique)
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is not None:
|
||||
if getattr(origin, "__qualname__", "") == "Annotated":
|
||||
args = get_args(annotation)
|
||||
return _normalize_annotation(args[0]) if args else ()
|
||||
if getattr(origin, "__qualname__", "") == "Literal":
|
||||
return ()
|
||||
if inspect.isclass(origin):
|
||||
return (origin,)
|
||||
|
||||
return ()
|
||||
|
||||
for name, parameter in inspect.signature(cls.__init__).parameters.items():
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
annotation = type_hints.get(name, parameter.annotation)
|
||||
|
||||
if isinstance(annotation, str):
|
||||
try:
|
||||
annotation = eval(annotation, module_globals, localns) # noqa: S307
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug(
|
||||
"Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc
|
||||
)
|
||||
annotation = inspect._empty
|
||||
|
||||
normalized = _normalize_annotation(annotation)
|
||||
|
||||
if normalized:
|
||||
signature_types[name] = normalized
|
||||
elif annotation not in (inspect._empty, None, Any):
|
||||
logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.")
|
||||
|
||||
for k, v in inspect.signature(cls.__init__).parameters.items():
|
||||
if inspect.isclass(v.annotation):
|
||||
signature_types[k] = (v.annotation,)
|
||||
elif get_origin(v.annotation) == Union:
|
||||
signature_types[k] = get_args(v.annotation)
|
||||
elif get_origin(v.annotation) in [List, Dict, list, dict]:
|
||||
signature_types[k] = (v.annotation,)
|
||||
else:
|
||||
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
|
||||
return signature_types
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user