From 3a0efa38f51c72cbf99556f0b3170ea991ff078b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Jan 2026 20:01:13 +0530 Subject: [PATCH] up --- src/diffusers/pipelines/pipeline_utils.py | 90 +++-------------------- 1 file changed, 10 insertions(+), 80 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 10d1a06f25..d4010ddb71 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -19,10 +19,9 @@ import inspect import os import re import sys -import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Union, get_args, get_origin import httpx import numpy as np @@ -1820,84 +1819,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): @classmethod def _get_signature_types(cls): signature_types = {} - module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {} - localns = dict(vars(cls)) - - try: - type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True) - except TypeError: - type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns) - except Exception as exc: - logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc) - type_hints = {} - - def _is_union(annotation: Any) -> bool: - origin = get_origin(annotation) - union_type = getattr(types, "UnionType", None) - if origin in (union_type): - return True - return union_type is not None and isinstance(annotation, union_type) - - def _normalize_annotation(annotation: Any) -> tuple[type, ...]: - if annotation is inspect._empty: - return (inspect.Signature.empty,) - - if annotation is None: - return (type(None),) - - if annotation is Any: - return (Any,) - - if inspect.isclass(annotation): - return (annotation,) - - if _is_union(annotation): - collected: list[type] = [] - for arg in get_args(annotation): - collected.extend(_normalize_annotation(arg)) - # preserve order while removing duplicates - unique: list[type] = [] - seen: set[type] = set() - for item in collected: - if item not in seen: - seen.add(item) - unique.append(item) - return tuple(unique) - - origin = get_origin(annotation) - if origin is not None: - if getattr(origin, "__qualname__", "") == "Annotated": - args = get_args(annotation) - return _normalize_annotation(args[0]) if args else () - if getattr(origin, "__qualname__", "") == "Literal": - return () - if inspect.isclass(origin): - return (origin,) - - return () - - for name, parameter in inspect.signature(cls.__init__).parameters.items(): - if name == "self": - continue - - annotation = type_hints.get(name, parameter.annotation) - - if isinstance(annotation, str): - try: - annotation = eval(annotation, module_globals, localns) # noqa: S307 - except Exception as exc: # noqa: BLE001 - logger.debug( - "Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc - ) - annotation = inspect._empty - - normalized = _normalize_annotation(annotation) - - if normalized: - signature_types[name] = normalized - elif annotation not in (inspect._empty, None, Any): - logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.") - + for k, v in inspect.signature(cls.__init__).parameters.items(): + if inspect.isclass(v.annotation): + signature_types[k] = (v.annotation,) + elif get_origin(v.annotation) == Union: + signature_types[k] = get_args(v.annotation) + elif get_origin(v.annotation) in [List, Dict, list, dict]: + signature_types[k] = (v.annotation,) + else: + logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") return signature_types @property