From ea77fdc4b4c50aaa7d5e0d619aa43457c277a603 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 6 Aug 2025 17:17:51 +0530 Subject: [PATCH] update --- .../modular_pipelines/modular_pipeline.py | 30 +++++++++++++++---- .../modular_pipeline_utils.py | 3 +- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 81cf519170..3f7436acfc 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -264,6 +264,18 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """List of input parameters. Must be implemented by subclasses.""" return [] + def _get_required_inputs(self): + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + + return input_names + + @property + def required_inputs(self) -> List[InputParam]: + return self._get_required_inputs() + @property def intermediate_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" @@ -492,6 +504,17 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): def output_names(self) -> List[str]: return [output_param.name for output_param in self.outputs] + @property + def doc(self): + return make_doc_string( + self.inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + class AutoPipelineBlocks(ModularPipelineBlocks): """ @@ -743,7 +766,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks): def doc(self): return make_doc_string( self.inputs, - self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -2394,16 +2416,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs - intermediate_inputs = [inp.name for inp in self.blocks.inputs] for expected_input_param in self.blocks.inputs: name = expected_input_param.name default = expected_input_param.default kwargs_type = expected_input_param.kwargs_type if name in passed_kwargs: - if name not in intermediate_inputs: - state.set(name, passed_kwargs.pop(name), kwargs_type) - else: - state.set(name, passed_kwargs[name], kwargs_type) + state.set(name, passed_kwargs.pop(name), kwargs_type) elif name not in state.values: state.set(name, default, kwargs_type) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index f2fc015e94..9118f13aa0 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines def make_doc_string( inputs, - intermediate_inputs, outputs, description="", class_name=None, @@ -664,7 +663,7 @@ def make_doc_string( output += configs_str + "\n\n" # Add inputs section - output += format_input_params(inputs + intermediate_inputs, indent_level=2) + output += format_input_params(inputs, indent_level=2) # Add outputs section output += "\n\n"