diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index c994b91ba8..1733ad6d4e 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -73,18 +73,72 @@ class PipelineState: inputs: Dict[str, Any] = field(default_factory=dict) intermediates: Dict[str, Any] = field(default_factory=dict) + input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) + intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - def add_input(self, key: str, value: Any): + def add_input(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an input to the pipeline state with optional metadata. + + Args: + key (str): The key for the input + value (Any): The input value + kwargs_type (str): The kwargs_type to store with the input + """ self.inputs[key] = value + if kwargs_type is not None: + if kwargs_type not in self.input_kwargs: + self.input_kwargs[kwargs_type] = [key] + else: + self.input_kwargs[kwargs_type].append(key) - def add_intermediate(self, key: str, value: Any): + def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + """ + Add an intermediate value to the pipeline state with optional metadata. + + Args: + key (str): The key for the intermediate value + value (Any): The intermediate value + kwargs_type (str): The kwargs_type to store with the intermediate value + """ self.intermediates[key] = value + if kwargs_type is not None: + if kwargs_type not in self.intermediate_kwargs: + self.intermediate_kwargs[kwargs_type] = [key] + else: + self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: return self.inputs.get(key, default) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} + + def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all inputs with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of inputs with matching kwargs_type + """ + input_names = self.input_kwargs.get(kwargs_type, []) + return self.get_inputs(input_names) + + def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + """ + Get all intermediates with matching kwargs_type. + + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) def get_intermediate(self, key: str, default: Any = None) -> Any: return self.intermediates.get(key, default) @@ -106,11 +160,17 @@ class PipelineState: inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) return ( f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" f")" ) @@ -146,10 +206,16 @@ class BlockState: # Handle dicts with tensor values elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) + formatted_dict = {} + for k, val in v.items(): + if hasattr(val, "shape") and hasattr(val, "dtype"): + formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" + elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): + shapes = [t.shape for t in val] + formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" + else: + formatted_dict[k] = repr(val) + return formatted_dict # Default case return repr(v) @@ -203,30 +269,34 @@ class ModularPipelineMixin: self.loader = None # Make a copy of the input kwargs - input_params = kwargs.copy() + passed_kwargs = kwargs.copy() - default_params = self.default_call_parameters # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: + for expected_input_param in self.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) + state.add_input(name, passed_kwargs.pop(name), kwargs_type) else: - state.add_input(name, input_params[name]) + state.add_input(name, passed_kwargs[name], kwargs_type) elif name not in state.inputs: - state.add_input(name, default) + state.add_input(name, default, kwargs_type) - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for expected_intermediate_param in self.intermediates_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + if len(passed_kwargs) > 0: + logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") # Run the pipeline with torch.no_grad(): try: @@ -390,25 +460,50 @@ class PipelineBlock(ModularPipelineMixin): # Check inputs for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v # Check intermediates for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v return BlockState(**data) def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index c8064a5215..f300f259f9 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -244,11 +244,12 @@ class ConfigSpec: @dataclass class InputParam: """Specification for an input parameter.""" - name: str + name: str = None type_hint: Any = None default: Any = None required: bool = False description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @@ -260,6 +261,7 @@ class OutputParam: name: str type_hint: Any = None description: str = "" + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"