mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -264,6 +264,18 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_required_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[InputParam]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
@@ -492,6 +504,17 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
@@ -743,7 +766,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -2394,16 +2416,12 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediate_inputs:
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.set(name, passed_kwargs[name], kwargs_type)
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
elif name not in state.values:
|
||||
state.set(name, default, kwargs_type)
|
||||
|
||||
|
||||
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -664,7 +663,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
output += format_input_params(inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
Reference in New Issue
Block a user