1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2026-01-13 11:52:04 +05:30
parent 987412b252
commit 463367d31d

View File

@@ -1829,18 +1829,34 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if v.annotation != inspect.Parameter.empty:
type_hints[k] = v.annotation
for k, annotation in type_hints.items():
if inspect.isclass(annotation):
signature_types[k] = (annotation,)
elif get_origin(annotation) == Union:
signature_types[k] = get_args(annotation)
elif isinstance(annotation, types.UnionType):
# Handle PEP 604 union syntax (X | Y) introduced in Python 3.10+
signature_types[k] = get_args(annotation)
elif get_origin(annotation) in [List, Dict, list, dict]:
signature_types[k] = (annotation,)
# Get all parameters from the signature to ensure we don't miss any
all_params = inspect.signature(cls.__init__).parameters
for param_name, param in all_params.items():
# Skip 'self' parameter
if param_name == "self":
continue
# If we have type hints, use them
if param_name in type_hints:
annotation = type_hints[param_name]
if inspect.isclass(annotation):
signature_types[param_name] = (annotation,)
elif get_origin(annotation) == Union:
signature_types[param_name] = get_args(annotation)
elif isinstance(annotation, types.UnionType):
# Handle PEP 604 union syntax (X | Y) introduced in Python 3.10+
signature_types[param_name] = get_args(annotation)
elif get_origin(annotation) in [List, Dict, list, dict]:
signature_types[param_name] = (annotation,)
else:
logger.warning(f"cannot get type annotation for Parameter {param_name} of {cls}.")
# Still add it with empty signature so it's in expected_types
signature_types[param_name] = (inspect.Signature.empty,)
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
# No type annotation found - add with empty signature
signature_types[param_name] = (inspect.Signature.empty,)
return signature_types
@property