1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

update doc & repr

This commit is contained in:
yiyixuxu
2025-04-22 10:33:03 +02:00
parent 78fca12803
commit 19555e95cc
2 changed files with 200 additions and 203 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Union
from enum import Enum
from ..utils import is_torch_available
@@ -27,4 +27,11 @@ if is_torch_available():
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
class GuiderType(Enum):
AdaptiveProjectedGuidance=1,
AutoGuidance=2,
ClassifierFreeGuidance=3,
ClassifierFreeZeroStarGuidance=4,
SkipLayerGuidance=5,
SmoothedEnergyGuidance=6,
TangentialClassifierFreeGuidance=7

View File

@@ -336,30 +336,143 @@ def format_params(params: List[Union[InputParam, OutputParam]], header: str = "A
# Then update the original functions to use this combined version:
def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str:
return format_params(input_params, "Args", indent_level, max_line_length)
return format_params(input_params, "Inputs", indent_level, max_line_length)
def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str:
return format_params(output_params, "Returns", indent_level, max_line_length)
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components: List[ComponentSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str:
"""Format a list of ComponentSpec objects into a readable string representation.
def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
Args:
components: List of ComponentSpec objects to format
indent_level: Number of spaces to indent each component line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between components (default: True)
Returns:
A formatted string representing all components
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
if component.default_repo:
if isinstance(component.default_repo, list) and len(component.default_repo) == 2:
repo_info = component.default_repo[0]
subfolder = component.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
repo_info = component.default_repo
component_desc += f" [{repo_info}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
def format_configs(configs: List[ConfigSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str:
"""Format a list of ConfigSpec objects into a readable string representation.
Args:
configs: List of ConfigSpec objects to format
indent_level: Number of spaces to indent each config line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between configs (default: True)
Returns:
A formatted string representing all configs
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None):
"""
Generates a formatted documentation string describing the pipeline block's parameters and structure.
Args:
inputs (List[InputParam]): List of input parameters
intermediates_inputs (List[InputParam]): List of intermediate input parameters
outputs (List[OutputParam]): List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
expected_components (List[ComponentSpec], *optional*): List of expected components
expected_configs (List[ConfigSpec], *optional*): List of expected configurations
Returns:
str: A formatted string containing information about call parameters, intermediate inputs/outputs,
and final intermediate outputs.
str: A formatted string containing information about components, configs, call parameters,
intermediate inputs/outputs, and final outputs.
"""
output = ""
# Add class name if provided
if class_name:
output += f"class {class_name}\n\n"
# Add description
if description:
desc_lines = description.strip().split('\n')
aligned_desc = '\n'.join(' ' + line for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
@@ -440,31 +553,15 @@ class PipelineBlock:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section - focus only on expected components
# Components section - use format_components with add_empty_lines=False
expected_components = getattr(self, "expected_components", [])
expected_components_str_list = []
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
components = " " + components_str.replace("\n", "\n ")
for component_spec in expected_components:
component_str = f" - {component_spec.name} ({component_spec.type_hint})"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
expected_components_str_list.append(component_str)
components = "Components:\n" + "\n".join(expected_components_str_list)
# Configs section - focus only on expected configs
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs))
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
configs = " " + configs_str.replace("\n", "\n ")
# Inputs section
inputs_str = format_inputs_short(self.inputs)
@@ -478,8 +575,8 @@ class PipelineBlock:
f"{class_name}(\n"
f" Class: {base_class}\n"
f"{desc}"
f" {components}\n"
f" {configs}\n"
f"{components}\n"
f"{configs}\n"
f" {inputs}\n"
f" {intermediates}\n"
f")"
@@ -488,7 +585,15 @@ class PipelineBlock:
@property
def doc(self):
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs
)
def get_block_state(self, state: PipelineState) -> dict:
@@ -796,32 +901,25 @@ class AutoPipelineBlocks:
# Components section - focus only on expected components
expected_components = getattr(self, "expected_components", [])
expected_components_str_list = []
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
for component_spec in expected_components:
component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
expected_components_str_list.append(component_str)
components_str = " Components:\n" + "\n".join(expected_components_str_list)
# Configs section - focus only on expected configs
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name))
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
# Blocks section
# Inputs and outputs section - moved up
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
intermediates_str = (
" Intermediates:\n"
f"{intermediates_str}\n"
f" - final outputs: {', '.join(outputs)}"
)
# Blocks section - moved to the end with simplified format
blocks_str = " Blocks:\n"
for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block
@@ -846,52 +944,31 @@ class AutoPipelineBlocks:
indented_desc = desc_lines[0]
if len(desc_lines) > 1:
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n"
# Format inputs
inputs_str = format_inputs_short(block.inputs)
blocks_str += f" inputs: {inputs_str}\n"
# Format intermediates
intermediates_str = format_intermediates_short(
block.intermediates_inputs,
block.required_intermediates_inputs,
block.intermediates_outputs
)
if intermediates_str != " (none)":
blocks_str += " intermediates:\n"
indented_intermediates = "\n".join(
" " + line for line in intermediates_str.split("\n")
)
blocks_str += f"{indented_intermediates}\n"
blocks_str += "\n"
# Inputs and outputs section
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
intermediates_str = (
"\n Intermediates:\n"
f"{intermediates_str}\n"
f" - final outputs: {', '.join(outputs)}"
)
blocks_str += f" Description: {indented_desc}\n\n"
return (
f"{header}\n"
f"{desc}"
f"{components_str}\n"
f"{configs_str}\n"
f"{blocks_str}\n"
f"{inputs_str}\n"
f"{intermediates_str}\n"
f"{blocks_str}"
f")"
)
@property
def doc(self):
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs
)
class SequentialPipelineBlocks:
"""
@@ -1166,34 +1243,27 @@ class SequentialPipelineBlocks:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section - focus only on expected components
# Components section - use format_components with add_empty_lines=False
expected_components = getattr(self, "expected_components", [])
expected_components_str_list = []
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
for component_spec in expected_components:
component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})"
# Add repo info if available
if component_spec.default_repo:
if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2:
repo_info = component_spec.default_repo[0]
subfolder = component_spec.default_repo[1]
if subfolder:
repo_info += f", subfolder={subfolder}"
else:
repo_info = component_spec.default_repo
component_str += f" [{repo_info}]"
expected_components_str_list.append(component_str)
components_str = " Components:\n" + "\n".join(expected_components_str_list)
# Configs section - focus only on expected configs
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name))
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
# Blocks section
# Inputs and outputs section - moved up
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
intermediates_str = (
" Intermediates:\n"
f"{intermediates_str}\n"
f" - final outputs: {', '.join(outputs)}"
)
# Blocks section - moved to the end with simplified format
blocks_str = " Blocks:\n"
for i, (name, block) in enumerate(self.blocks.items()):
# Get trigger input for this block
@@ -1218,53 +1288,31 @@ class SequentialPipelineBlocks:
indented_desc = desc_lines[0]
if len(desc_lines) > 1:
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n"
# Format inputs
inputs_str = format_inputs_short(block.inputs)
blocks_str += f" inputs: {inputs_str}\n"
# Format intermediates
intermediates_str = format_intermediates_short(
block.intermediates_inputs,
block.required_intermediates_inputs,
block.intermediates_outputs
)
if intermediates_str != " (none)":
blocks_str += " intermediates:\n"
indented_intermediates = "\n".join(
" " + line for line in intermediates_str.split("\n")
)
blocks_str += f"{indented_intermediates}\n"
blocks_str += "\n"
# Inputs and outputs section
inputs_str = format_inputs_short(self.inputs)
inputs_str = " Inputs:\n " + inputs_str
outputs = [out.name for out in self.outputs]
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
intermediates_str = (
"\n Intermediates:\n"
f"{intermediates_str}\n"
f" - final outputs: {', '.join(outputs)}"
)
blocks_str += f" Description: {indented_desc}\n\n"
return (
f"{header}\n"
f"{desc}"
f"{components_str}\n"
f"{configs_str}\n"
f"{blocks_str}\n"
f"{inputs_str}\n"
f"{intermediates_str}\n"
f"{blocks_str}"
f")"
)
@property
def doc(self):
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs
)
class ModularPipeline(ConfigMixin):
"""
@@ -1486,64 +1534,6 @@ class ModularPipeline(ConfigMixin):
params[input_param.name] = input_param.default
return params
# def __repr__(self):
# output = "ModularPipeline:\n"
# output += "==============================\n\n"
# block = self.pipeline_block
# # List the pipeline block structure first
# output += "Pipeline Block:\n"
# output += "--------------\n"
# if hasattr(block, "blocks"):
# output += f"{block.__class__.__name__}\n"
# base_class = block.__class__.__bases__[0].__name__
# output += f" (Class: {base_class})\n" if base_class != "object" else "\n"
# for sub_block_name, sub_block in block.blocks.items():
# if hasattr(block, "block_trigger_inputs"):
# trigger_input = block.block_to_trigger_map[sub_block_name]
# trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]"
# output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n"
# else:
# output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n"
# else:
# output += f"{block.__class__.__name__}\n"
# output += "\n"
# # List the components registered in the pipeline
# output += "Registered Components:\n"
# output += "----------------------\n"
# for name, component in self.components.items():
# output += f"{name}: {type(component).__name__}"
# if hasattr(component, "dtype") and hasattr(component, "device"):
# output += f" (dtype={component.dtype}, device={component.device})"
# output += "\n"
# output += "\n"
# # List the configs registered in the pipeline
# output += "Registered Configs:\n"
# output += "------------------\n"
# for name, config in self.config.items():
# output += f"{name}: {config!r}\n"
# output += "\n"
# # Add auto blocks section
# if hasattr(block, "trigger_inputs") and block.trigger_inputs:
# output += "------------------\n"
# output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n"
# output += f"Trigger Inputs: {block.trigger_inputs}\n"
# # Get first trigger input as example
# example_input = next(t for t in block.trigger_inputs if t is not None)
# output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
# output += "Check `.doc` of returned object for more information.\n\n"
# # List the call parameters
# full_doc = self.pipeline_block.doc
# if "------------------------" in full_doc:
# full_doc = full_doc.split("------------------------")[0].rstrip()
# output += full_doc
# return output
# YiYi TODO: try to unify the to method with the one in DiffusionPipeline
# Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to