mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
move componentspec, configspec, input output param to utils
This commit is contained in:
@@ -13,14 +13,61 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..utils.import_utils import is_torch_available
|
||||
from ..configuration_utils import FrozenDict
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentSpec:
|
||||
"""Specification for a pipeline component."""
|
||||
name: str
|
||||
# YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance
|
||||
type_hint: Type
|
||||
description: Optional[str] = None
|
||||
config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor
|
||||
repo: Optional[Union[str, List[str]]] = None
|
||||
subfolder: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
variant: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ConfigSpec:
|
||||
"""Specification for a pipeline configuration parameter."""
|
||||
name: str
|
||||
value: Any
|
||||
description: Optional[str] = None
|
||||
repo: Optional[Union[str, List[str]]] = None
|
||||
|
||||
@dataclass
|
||||
class InputParam:
|
||||
"""Specification for an input parameter."""
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
default: Any = None
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputParam:
|
||||
"""Specification for an output parameter."""
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||
|
||||
|
||||
def format_inputs_short(inputs):
|
||||
"""
|
||||
Format input parameters into a string representation, with required params first followed by optional ones.
|
||||
|
||||
Reference in New Issue
Block a user