1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-08-06 17:17:51 +05:30
parent 255c5742aa
commit ea77fdc4b4
2 changed files with 25 additions and 8 deletions

View File

@@ -264,6 +264,18 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""List of input parameters. Must be implemented by subclasses."""
return []
def _get_required_inputs(self):
input_names = []
for input_param in self.inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@property
def required_inputs(self) -> List[InputParam]:
return self._get_required_inputs()
@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
@@ -492,6 +504,17 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
@property
def doc(self):
return make_doc_string(
self.inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs,
)
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
@@ -743,7 +766,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -2394,16 +2416,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediate_inputs = [inp.name for inp in self.blocks.inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediate_inputs:
state.set(name, passed_kwargs.pop(name), kwargs_type)
else:
state.set(name, passed_kwargs[name], kwargs_type)
state.set(name, passed_kwargs.pop(name), kwargs_type)
elif name not in state.values:
state.set(name, default, kwargs_type)

View File

@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string(
inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
@@ -664,7 +663,7 @@ def make_doc_string(
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
output += format_input_params(inputs, indent_level=2)
# Add outputs section
output += "\n\n"