Mirror of https://github.com/huggingface/diffusers.git

refactor pipeline/block states so that they can dynamically accept kwargs

This commit is contained in:
yiyixuxu
2025-05-06 09:58:44 +02:00
parent 43ac1ff7e7
commit dc4dbfe107
2 changed files with 127 additions and 30 deletions

Changed file 1 of 2

@@ -73,18 +73,72 @@ class PipelineState:
     inputs: Dict[str, Any] = field(default_factory=dict)
     intermediates: Dict[str, Any] = field(default_factory=dict)
+    input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict)
+    intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict)
 
-    def add_input(self, key: str, value: Any):
+    def add_input(self, key: str, value: Any, kwargs_type: str = None):
+        """
+        Add an input to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the input
+            value (Any): The input value
+            kwargs_type (str): The kwargs_type to store with the input
+        """
         self.inputs[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.input_kwargs:
+                self.input_kwargs[kwargs_type] = [key]
+            else:
+                self.input_kwargs[kwargs_type].append(key)
 
-    def add_intermediate(self, key: str, value: Any):
+    def add_intermediate(self, key: str, value: Any, kwargs_type: str = None):
+        """
+        Add an intermediate value to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the intermediate value
+            value (Any): The intermediate value
+            kwargs_type (str): The kwargs_type to store with the intermediate value
+        """
         self.intermediates[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.intermediate_kwargs:
+                self.intermediate_kwargs[kwargs_type] = [key]
+            else:
+                self.intermediate_kwargs[kwargs_type].append(key)
 
     def get_input(self, key: str, default: Any = None) -> Any:
         return self.inputs.get(key, default)
 
     def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
         return {key: self.inputs.get(key, default) for key in keys}
 
+    def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all inputs with matching kwargs_type.
+
+        Args:
+            kwargs_type (str): The kwargs_type to filter by
+
+        Returns:
+            Dict[str, Any]: Dictionary of inputs with matching kwargs_type
+        """
+        input_names = self.input_kwargs.get(kwargs_type, [])
+        return self.get_inputs(input_names)
+
+    def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all intermediates with matching kwargs_type.
+
+        Args:
+            kwargs_type (str): The kwargs_type to filter by
+
+        Returns:
+            Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
+        """
+        intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
+        return self.get_intermediates(intermediate_names)
+
     def get_intermediate(self, key: str, default: Any = None) -> Any:
         return self.intermediates.get(key, default)
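For illustration, a minimal standalone sketch (not part of the commit) of the bookkeeping added above: values registered with a kwargs_type are remembered under that group and can later be fetched together. The class name and the "ip_adapter" group are made up for this example.

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class StateSketch:
    inputs: Dict[str, Any] = field(default_factory=dict)
    input_kwargs: Dict[str, List[str]] = field(default_factory=dict)

    def add_input(self, key: str, value: Any, kwargs_type: str = None):
        # store the value; if a group is given, remember the key under that group
        self.inputs[key] = value
        if kwargs_type is not None:
            if kwargs_type not in self.input_kwargs:
                self.input_kwargs[kwargs_type] = [key]
            else:
                self.input_kwargs[kwargs_type].append(key)

    def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
        # return every input that was registered under this group
        return {k: self.inputs.get(k) for k in self.input_kwargs.get(kwargs_type, [])}

state = StateSketch()
state.add_input("prompt", "a photo of a cat")  # ungrouped
state.add_input("ip_adapter_image", "<pil image>", kwargs_type="ip_adapter")
state.add_input("ip_adapter_scale", 0.6, kwargs_type="ip_adapter")
print(state.get_inputs_kwargs("ip_adapter"))
# -> {'ip_adapter_image': '<pil image>', 'ip_adapter_scale': 0.6}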
@@ -106,11 +160,17 @@ class PipelineState:
         inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
         intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
+
+        # Format input_kwargs and intermediate_kwargs
+        input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
+        intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
 
         return (
             f"PipelineState(\n"
             f" inputs={{\n{inputs}\n }},\n"
-            f" intermediates={{\n{intermediates}\n }}\n"
+            f" intermediates={{\n{intermediates}\n }},\n"
+            f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
+            f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
             f")"
         )
@@ -146,10 +206,16 @@ class BlockState:
             # Handle dicts with tensor values
             elif isinstance(v, dict):
-                if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()):
-                    shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")}
-                    return f"Dict of Tensors with shapes {shapes}"
-                return repr(v)
+                formatted_dict = {}
+                for k, val in v.items():
+                    if hasattr(val, "shape") and hasattr(val, "dtype"):
+                        formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
+                    elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"):
+                        shapes = [t.shape for t in val]
+                        formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
+                    else:
+                        formatted_dict[k] = repr(val)
+                return formatted_dict
 
             # Default case
             return repr(v)
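A quick illustration (also not part of the diff) of what the reworked branch produces for a dict that mixes tensors, lists of tensors, and plain values; the helper name and sample entries are hypothetical:

import torch

def format_tensor_dict(v):
    # mirrors the dict branch added above
    formatted_dict = {}
    for k, val in v.items():
        if hasattr(val, "shape") and hasattr(val, "dtype"):
            formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
        elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"):
            shapes = [t.shape for t in val]
            formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
        else:
            formatted_dict[k] = repr(val)
    return formatted_dict

print(format_tensor_dict({
    "prompt_embeds": torch.ones(2, 77, 768),
    "image_embeds": [torch.ones(2, 4), torch.ones(2, 4)],
    "guidance_scale": 7.5,
}))
# -> {'prompt_embeds': 'Tensor(shape=torch.Size([2, 77, 768]), dtype=torch.float32)',
#     'image_embeds': 'List[2] of Tensors with shapes [torch.Size([2, 4]), torch.Size([2, 4])]',
#     'guidance_scale': '7.5'}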
@@ -203,30 +269,34 @@ class ModularPipelineMixin:
         self.loader = None
 
         # Make a copy of the input kwargs
-        input_params = kwargs.copy()
+        passed_kwargs = kwargs.copy()
         default_params = self.default_call_parameters
 
         # Add inputs to state, using defaults if not provided in the kwargs or the state
         # if same input already in the state, will override it if provided in the kwargs
         intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
-        for name, default in default_params.items():
-            if name in input_params:
+        for expected_input_param in self.inputs:
+            name = expected_input_param.name
+            default = expected_input_param.default
+            kwargs_type = expected_input_param.kwargs_type
+            if name in passed_kwargs:
                 if name not in intermediates_inputs:
-                    state.add_input(name, input_params.pop(name))
+                    state.add_input(name, passed_kwargs.pop(name), kwargs_type)
                 else:
-                    state.add_input(name, input_params[name])
+                    state.add_input(name, passed_kwargs[name], kwargs_type)
             elif name not in state.inputs:
-                state.add_input(name, default)
+                state.add_input(name, default, kwargs_type)
 
-        for name in intermediates_inputs:
-            if name in input_params:
-                state.add_intermediate(name, input_params.pop(name))
+        for expected_intermediate_param in self.intermediates_inputs:
+            name = expected_intermediate_param.name
+            kwargs_type = expected_intermediate_param.kwargs_type
+            if name in passed_kwargs:
+                state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
 
         # Warn about unexpected inputs
-        if len(input_params) > 0:
-            logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.")
+        if len(passed_kwargs) > 0:
+            logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
 
         # Run the pipeline
         with torch.no_grad():
             try:
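The net effect of the rewritten loop: the pipeline now walks its declared input params instead of a flat name-to-default mapping, pulls matching values out of the caller's kwargs, and forwards each param's kwargs_type to the state; anything left over still triggers the warning. A rough, self-contained sketch of that routing with stand-in param and state classes (illustrative only, not the library's API):

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

@dataclass
class ParamStub:
    name: str = None
    default: Any = None
    kwargs_type: str = None

@dataclass
class StateStub:
    inputs: Dict[str, Any] = field(default_factory=dict)
    input_kwargs: Dict[str, List[str]] = field(default_factory=dict)

    def add_input(self, key, value, kwargs_type=None):
        self.inputs[key] = value
        if kwargs_type is not None:
            self.input_kwargs.setdefault(kwargs_type, []).append(key)

def route_inputs(expected_inputs, state, **kwargs):
    passed_kwargs = kwargs.copy()
    for p in expected_inputs:
        # take the caller's value when given, otherwise fall back to the declared default
        if p.name in passed_kwargs:
            state.add_input(p.name, passed_kwargs.pop(p.name), p.kwargs_type)
        elif p.name not in state.inputs:
            state.add_input(p.name, p.default, p.kwargs_type)
    if len(passed_kwargs) > 0:
        logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")

expected = [ParamStub("prompt"), ParamStub("guider_scale", default=5.0, kwargs_type="guider_kwargs")]
state = StateStub()
route_inputs(expected, state, prompt="a cat", typo_arg=1)  # typo_arg is warned about and dropped
print(state.inputs)        # {'prompt': 'a cat', 'guider_scale': 5.0}
print(state.input_kwargs)  # {'guider_kwargs': ['guider_scale']}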
@@ -390,25 +460,50 @@ class PipelineBlock(ModularPipelineMixin):
         # Check inputs
         for input_param in self.inputs:
-            value = state.get_input(input_param.name)
-            if input_param.required and value is None:
-                raise ValueError(f"Required input '{input_param.name}' is missing")
-            data[input_param.name] = value
+            if input_param.name:
+                value = state.get_input(input_param.name)
+                if input_param.required and value is None:
+                    raise ValueError(f"Required input '{input_param.name}' is missing")
+                elif value is not None or (value is None and input_param.name not in data):
+                    data[input_param.name] = value
+            elif input_param.kwargs_type:
+                # if kwargs_type is provided, get all inputs with matching kwargs_type
+                if input_param.kwargs_type not in data:
+                    data[input_param.kwargs_type] = {}
+                inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
+                if inputs_kwargs:
+                    for k, v in inputs_kwargs.items():
+                        if v is not None:
+                            data[k] = v
+                            data[input_param.kwargs_type][k] = v
 
         # Check intermediates
         for input_param in self.intermediates_inputs:
-            value = state.get_intermediate(input_param.name)
-            if input_param.required and value is None:
-                raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
-            data[input_param.name] = value
+            if input_param.name:
+                value = state.get_intermediate(input_param.name)
+                if input_param.required and value is None:
+                    raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
+                elif value is not None or (value is None and input_param.name not in data):
+                    data[input_param.name] = value
+            elif input_param.kwargs_type:
+                # if kwargs_type is provided, get all intermediates with matching kwargs_type
+                if input_param.kwargs_type not in data:
+                    data[input_param.kwargs_type] = {}
+                intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
+                if intermediates_kwargs:
+                    for k, v in intermediates_kwargs.items():
+                        if v is not None:
+                            if k not in data:
+                                data[k] = v
+                            data[input_param.kwargs_type][k] = v
 
         return BlockState(**data)
 
     def add_block_state(self, state: PipelineState, block_state: BlockState):
         for output_param in self.intermediates_outputs:
             if not hasattr(block_state, output_param.name):
                 raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
-            state.add_intermediate(output_param.name, getattr(block_state, output_param.name))
+            param = getattr(block_state, output_param.name)
+            state.add_intermediate(output_param.name, param, output_param.kwargs_type)
 
 
 def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
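To see what the name/kwargs_type split buys a block: an input declared by group only (no name, kwargs_type set) makes get_block_state hand the block every matching value, bundled into one dict. A condensed sketch of that branch with stand-in objects (the real get_block_state also handles required checks and intermediates):

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class ParamStub:
    name: str = None
    required: bool = False
    kwargs_type: str = None

def collect_block_data(declared_inputs, state_inputs: Dict[str, Any], input_kwargs: Dict[str, List[str]]):
    # condensed version of the kwargs_type branch added to get_block_state above
    data = {}
    for p in declared_inputs:
        if p.name:
            data[p.name] = state_inputs.get(p.name)
        elif p.kwargs_type:
            data.setdefault(p.kwargs_type, {})
            for k in input_kwargs.get(p.kwargs_type, []):
                v = state_inputs.get(k)
                if v is not None:
                    data[k] = v
                    data[p.kwargs_type][k] = v
    return data

state_inputs = {"prompt": "a cat", "guider_scale": 5.0, "guider_eta": 0.0}
input_kwargs = {"guider_kwargs": ["guider_scale", "guider_eta"]}
declared = [ParamStub(name="prompt"), ParamStub(kwargs_type="guider_kwargs")]
print(collect_block_data(declared, state_inputs, input_kwargs))
# -> {'prompt': 'a cat',
#     'guider_kwargs': {'guider_scale': 5.0, 'guider_eta': 0.0},
#     'guider_scale': 5.0, 'guider_eta': 0.0}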

Changed file 2 of 2

@@ -244,11 +244,12 @@ class ConfigSpec:
 @dataclass
 class InputParam:
     """Specification for an input parameter."""
-    name: str
+    name: str = None
     type_hint: Any = None
     default: Any = None
     required: bool = False
     description: str = ""
+    kwargs_type: str = None
 
     def __repr__(self):
         return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@@ -260,6 +261,7 @@ class OutputParam:
     name: str
     type_hint: Any = None
     description: str = ""
+    kwargs_type: str = None
 
     def __repr__(self):
         return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"