mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
reefactor pipeline/block states so that it can dynamically accept kwargs
This commit is contained in:
@@ -73,18 +73,72 @@ class PipelineState:
|
||||
|
||||
inputs: Dict[str, Any] = field(default_factory=dict)
|
||||
intermediates: Dict[str, Any] = field(default_factory=dict)
|
||||
input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict)
|
||||
intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def add_input(self, key: str, value: Any):
|
||||
def add_input(self, key: str, value: Any, kwargs_type: str = None):
|
||||
"""
|
||||
Add an input to the pipeline state with optional metadata.
|
||||
|
||||
Args:
|
||||
key (str): The key for the input
|
||||
value (Any): The input value
|
||||
kwargs_type (str): The kwargs_type to store with the input
|
||||
"""
|
||||
self.inputs[key] = value
|
||||
if kwargs_type is not None:
|
||||
if kwargs_type not in self.input_kwargs:
|
||||
self.input_kwargs[kwargs_type] = [key]
|
||||
else:
|
||||
self.input_kwargs[kwargs_type].append(key)
|
||||
|
||||
def add_intermediate(self, key: str, value: Any):
|
||||
def add_intermediate(self, key: str, value: Any, kwargs_type: str = None):
|
||||
"""
|
||||
Add an intermediate value to the pipeline state with optional metadata.
|
||||
|
||||
Args:
|
||||
key (str): The key for the intermediate value
|
||||
value (Any): The intermediate value
|
||||
kwargs_type (str): The kwargs_type to store with the intermediate value
|
||||
"""
|
||||
self.intermediates[key] = value
|
||||
if kwargs_type is not None:
|
||||
if kwargs_type not in self.intermediate_kwargs:
|
||||
self.intermediate_kwargs[kwargs_type] = [key]
|
||||
else:
|
||||
self.intermediate_kwargs[kwargs_type].append(key)
|
||||
|
||||
def get_input(self, key: str, default: Any = None) -> Any:
|
||||
return self.inputs.get(key, default)
|
||||
|
||||
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
|
||||
return {key: self.inputs.get(key, default) for key in keys}
|
||||
|
||||
def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all inputs with matching kwargs_type.
|
||||
|
||||
Args:
|
||||
kwargs_type (str): The kwargs_type to filter by
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of inputs with matching kwargs_type
|
||||
"""
|
||||
input_names = self.input_kwargs.get(kwargs_type, [])
|
||||
return self.get_inputs(input_names)
|
||||
|
||||
def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all intermediates with matching kwargs_type.
|
||||
|
||||
Args:
|
||||
kwargs_type (str): The kwargs_type to filter by
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
|
||||
"""
|
||||
intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
|
||||
return self.get_intermediates(intermediate_names)
|
||||
|
||||
def get_intermediate(self, key: str, default: Any = None) -> Any:
|
||||
return self.intermediates.get(key, default)
|
||||
@@ -106,11 +160,17 @@ class PipelineState:
|
||||
|
||||
inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
|
||||
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
|
||||
|
||||
# Format input_kwargs and intermediate_kwargs
|
||||
input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
|
||||
intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
|
||||
|
||||
return (
|
||||
f"PipelineState(\n"
|
||||
f" inputs={{\n{inputs}\n }},\n"
|
||||
f" intermediates={{\n{intermediates}\n }}\n"
|
||||
f" intermediates={{\n{intermediates}\n }},\n"
|
||||
f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
|
||||
f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -146,10 +206,16 @@ class BlockState:
|
||||
|
||||
# Handle dicts with tensor values
|
||||
elif isinstance(v, dict):
|
||||
if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()):
|
||||
shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")}
|
||||
return f"Dict of Tensors with shapes {shapes}"
|
||||
return repr(v)
|
||||
formatted_dict = {}
|
||||
for k, val in v.items():
|
||||
if hasattr(val, "shape") and hasattr(val, "dtype"):
|
||||
formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
|
||||
elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"):
|
||||
shapes = [t.shape for t in val]
|
||||
formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
|
||||
else:
|
||||
formatted_dict[k] = repr(val)
|
||||
return formatted_dict
|
||||
|
||||
# Default case
|
||||
return repr(v)
|
||||
@@ -203,30 +269,34 @@ class ModularPipelineMixin:
|
||||
self.loader = None
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
input_params = kwargs.copy()
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
default_params = self.default_call_parameters
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
|
||||
intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
|
||||
for name, default in default_params.items():
|
||||
if name in input_params:
|
||||
for expected_input_param in self.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediates_inputs:
|
||||
state.add_input(name, input_params.pop(name))
|
||||
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.add_input(name, input_params[name])
|
||||
state.add_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.add_input(name, default)
|
||||
state.add_input(name, default, kwargs_type)
|
||||
|
||||
for name in intermediates_inputs:
|
||||
if name in input_params:
|
||||
state.add_intermediate(name, input_params.pop(name))
|
||||
for expected_intermediate_param in self.intermediates_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(input_params) > 0:
|
||||
logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.")
|
||||
if len(passed_kwargs) > 0:
|
||||
logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
@@ -390,25 +460,50 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
|
||||
# Check inputs
|
||||
for input_param in self.inputs:
|
||||
value = state.get_input(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
data[input_param.name] = value
|
||||
if input_param.name:
|
||||
value = state.get_input(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
|
||||
if inputs_kwargs:
|
||||
for k, v in inputs_kwargs.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
# Check intermediates
|
||||
for input_param in self.intermediates_inputs:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
||||
data[input_param.name] = value
|
||||
|
||||
if input_param.name:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
data[input_param.name] = value
|
||||
elif input_param.kwargs_type:
|
||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
if intermediates_kwargs:
|
||||
for k, v in intermediates_kwargs.items():
|
||||
if v is not None:
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
data[input_param.kwargs_type][k] = v
|
||||
return BlockState(**data)
|
||||
|
||||
def add_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediates_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
state.add_intermediate(output_param.name, getattr(block_state, output_param.name))
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
|
||||
@@ -244,11 +244,12 @@ class ConfigSpec:
|
||||
@dataclass
|
||||
class InputParam:
|
||||
"""Specification for an input parameter."""
|
||||
name: str
|
||||
name: str = None
|
||||
type_hint: Any = None
|
||||
default: Any = None
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
@@ -260,6 +261,7 @@ class OutputParam:
|
||||
name: str
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
kwargs_type: str = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
|
||||
|
||||
Reference in New Issue
Block a user