From 2571c000547da5782e5ccd0ce448adee2a221026 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 23 Apr 2025 19:43:34 +0200 Subject: [PATCH] move componentspec, configspec, input output param to utils --- .../pipelines/modular_pipeline_utils.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index fb6b83c7ee..0fec1db91e 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -13,14 +13,61 @@ # limitations under the License. import re -from typing import Any, Dict, List, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict if is_torch_available(): import torch +@dataclass +class ComponentSpec: + """Specification for a pipeline component.""" + name: str + # YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance + type_hint: Type + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor + repo: Optional[Union[str, List[str]]] = None + subfolder: Optional[str] = None + revision: Optional[str] = None + variant: Optional[str] = None + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + value: Any + description: Optional[str] = None + repo: Optional[Union[str, List[str]]] = None + +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + def format_inputs_short(inputs): """ Format input parameters into a string representation, with required params first followed by optional ones.