mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -19,9 +19,10 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
||||
from typing import Any, Callable, Optional, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
@@ -1073,6 +1074,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
# 10. Type checking init arguments
|
||||
print(f"{expected_types.keys()=}")
|
||||
for kw, arg in init_kwargs.items():
|
||||
# Too complex to validate with type annotation alone
|
||||
if "scheduler" in kw:
|
||||
@@ -1816,9 +1818,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
for k, v in inspect.signature(cls.__init__).parameters.items():
|
||||
if inspect.isclass(v.annotation):
|
||||
signature_types[k] = (v.annotation,)
|
||||
elif get_origin(v.annotation) == Union:
|
||||
elif get_origin(v.annotation) in [Union, types.UnionType]:
|
||||
signature_types[k] = get_args(v.annotation)
|
||||
elif get_origin(v.annotation) in [list, Dict, list, dict]:
|
||||
elif get_origin(v.annotation) in [list, dict]:
|
||||
signature_types[k] = (v.annotation,)
|
||||
else:
|
||||
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
|
||||
|
||||
@@ -12,7 +12,7 @@ class ReturnNameVisitor(ast.NodeVisitor):
|
||||
|
||||
def visit_Return(self, node):
|
||||
# Check if the return value is a tuple.
|
||||
if isinstance(node.value, ast.tuple):
|
||||
if isinstance(node.value, ast.Tuple):
|
||||
for elt in node.value.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
self.return_names.append(elt.id)
|
||||
|
||||
Reference in New Issue
Block a user