diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2d1bb98904..b550de03b0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -19,9 +19,10 @@ import inspect import os import re import sys +import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin +from typing import Any, Callable, Optional, Union, get_args, get_origin import httpx import numpy as np @@ -1073,6 +1074,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ) # 10. Type checking init arguments + print(f"{expected_types.keys()=}") for kw, arg in init_kwargs.items(): # Too complex to validate with type annotation alone if "scheduler" in kw: @@ -1816,9 +1818,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): for k, v in inspect.signature(cls.__init__).parameters.items(): if inspect.isclass(v.annotation): signature_types[k] = (v.annotation,) - elif get_origin(v.annotation) == Union: + elif get_origin(v.annotation) in [Union, types.UnionType]: signature_types[k] = get_args(v.annotation) - elif get_origin(v.annotation) in [list, Dict, list, dict]: + elif get_origin(v.annotation) in [list, dict]: signature_types[k] = (v.annotation,) else: logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") diff --git a/src/diffusers/utils/source_code_parsing_utils.py b/src/diffusers/utils/source_code_parsing_utils.py index c69b40d11b..5f94711c21 100644 --- a/src/diffusers/utils/source_code_parsing_utils.py +++ b/src/diffusers/utils/source_code_parsing_utils.py @@ -12,7 +12,7 @@ class ReturnNameVisitor(ast.NodeVisitor): def visit_Return(self, node): # Check if the return value is a tuple. - if isinstance(node.value, ast.tuple): + if isinstance(node.value, ast.Tuple): for elt in node.value.elts: if isinstance(elt, ast.Name): self.return_names.append(elt.id)