1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2025-10-27 17:56:37 +05:30
parent ca5afaebca
commit 585c32b304
2 changed files with 6 additions and 4 deletions

View File

@@ -19,9 +19,10 @@ import inspect
import os
import re
import sys
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
from typing import Any, Callable, Optional, Union, get_args, get_origin
import httpx
import numpy as np
@@ -1073,6 +1074,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
# 10. Type checking init arguments
print(f"{expected_types.keys()=}")
for kw, arg in init_kwargs.items():
# Too complex to validate with type annotation alone
if "scheduler" in kw:
@@ -1816,9 +1818,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for k, v in inspect.signature(cls.__init__).parameters.items():
if inspect.isclass(v.annotation):
signature_types[k] = (v.annotation,)
elif get_origin(v.annotation) == Union:
elif get_origin(v.annotation) in [Union, types.UnionType]:
signature_types[k] = get_args(v.annotation)
elif get_origin(v.annotation) in [list, Dict, list, dict]:
elif get_origin(v.annotation) in [list, dict]:
signature_types[k] = (v.annotation,)
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")

View File

@@ -12,7 +12,7 @@ class ReturnNameVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Check if the return value is a tuple.
if isinstance(node.value, ast.tuple):
if isinstance(node.value, ast.Tuple):
for elt in node.value.elts:
if isinstance(elt, ast.Name):
self.return_names.append(elt.id)