diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 56e95c92b6..a0a905f2b5 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Union - +from enum import Enum from ..utils import is_torch_available @@ -27,4 +27,11 @@ if is_torch_available(): from .smoothed_energy_guidance import SmoothedEnergyGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] + class GuiderType(Enum): + AdaptiveProjectedGuidance=1, + AutoGuidance=2, + ClassifierFreeGuidance=3, + ClassifierFreeZeroStarGuidance=4, + SkipLayerGuidance=5, + SmoothedEnergyGuidance=6, + TangentialClassifierFreeGuidance=7 diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 785f38cdbf..b896066edf 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -336,30 +336,143 @@ def format_params(params: List[Union[InputParam, OutputParam]], header: str = "A # Then update the original functions to use this combined version: def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) + return format_params(input_params, "Inputs", indent_level, max_line_length) def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) + return format_params(output_params, "Outputs", indent_level, max_line_length) +def format_components(components: List[ComponentSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: + """Format a list of ComponentSpec objects into a readable string representation. -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + if component.default_repo: + if isinstance(component.default_repo, list) and len(component.default_repo) == 2: + repo_info = component.default_repo[0] + subfolder = component.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" + else: + repo_info = component.default_repo + component_desc += f" [{repo_info}]" + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs: List[ConfigSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): """ Generates a formatted documentation string describing the pipeline block's parameters and structure. + Args: + inputs (List[InputParam]): List of input parameters + intermediates_inputs (List[InputParam]): List of intermediate input parameters + outputs (List[OutputParam]): List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. """ output = "" + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description if description: desc_lines = description.strip().split('\n') aligned_desc = '\n'.join(' ' + line for line in desc_lines) output += aligned_desc + "\n\n" + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section output += format_input_params(inputs + intermediates_inputs, indent_level=2) + # Add outputs section output += "\n\n" output += format_output_params(outputs, indent_level=2) @@ -440,31 +553,15 @@ class PipelineBlock: desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") - for component_spec in expected_components: - component_str = f" - {component_spec.name} ({component_spec.type_hint})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components = "Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") # Inputs section inputs_str = format_inputs_short(self.inputs) @@ -478,8 +575,8 @@ class PipelineBlock: f"{class_name}(\n" f" Class: {base_class}\n" f"{desc}" - f" {components}\n" - f" {configs}\n" + f"{components}\n" + f"{configs}\n" f" {inputs}\n" f" {intermediates}\n" f")" @@ -488,7 +585,15 @@ class PipelineBlock: @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) def get_block_state(self, state: PipelineState) -> dict: @@ -796,32 +901,25 @@ class AutoPipelineBlocks: # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Inputs and outputs section - moved up + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + + outputs = [out.name for out in self.outputs] + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = ( + " Intermediates:\n" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" + ) + + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -846,52 +944,31 @@ class AutoPipelineBlocks: indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" f"{desc}" f"{components_str}\n" f"{configs_str}\n" - f"{blocks_str}\n" f"{inputs_str}\n" f"{intermediates_str}\n" + f"{blocks_str}" f")" ) + @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) class SequentialPipelineBlocks: """ @@ -1166,34 +1243,27 @@ class SequentialPipelineBlocks: desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Inputs and outputs section - moved up + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + + outputs = [out.name for out in self.outputs] + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = ( + " Intermediates:\n" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" + ) + + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1218,53 +1288,31 @@ class SequentialPipelineBlocks: indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" f"{desc}" f"{components_str}\n" f"{configs_str}\n" - f"{blocks_str}\n" f"{inputs_str}\n" f"{intermediates_str}\n" + f"{blocks_str}" f")" ) @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) class ModularPipeline(ConfigMixin): """ @@ -1486,64 +1534,6 @@ class ModularPipeline(ConfigMixin): params[input_param.name] = input_param.default return params - # def __repr__(self): - # output = "ModularPipeline:\n" - # output += "==============================\n\n" - - # block = self.pipeline_block - - # # List the pipeline block structure first - # output += "Pipeline Block:\n" - # output += "--------------\n" - # if hasattr(block, "blocks"): - # output += f"{block.__class__.__name__}\n" - # base_class = block.__class__.__bases__[0].__name__ - # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - # for sub_block_name, sub_block in block.blocks.items(): - # if hasattr(block, "block_trigger_inputs"): - # trigger_input = block.block_to_trigger_map[sub_block_name] - # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - # else: - # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - # else: - # output += f"{block.__class__.__name__}\n" - # output += "\n" - - # # List the components registered in the pipeline - # output += "Registered Components:\n" - # output += "----------------------\n" - # for name, component in self.components.items(): - # output += f"{name}: {type(component).__name__}" - # if hasattr(component, "dtype") and hasattr(component, "device"): - # output += f" (dtype={component.dtype}, device={component.device})" - # output += "\n" - # output += "\n" - - # # List the configs registered in the pipeline - # output += "Registered Configs:\n" - # output += "------------------\n" - # for name, config in self.config.items(): - # output += f"{name}: {config!r}\n" - # output += "\n" - - # # Add auto blocks section - # if hasattr(block, "trigger_inputs") and block.trigger_inputs: - # output += "------------------\n" - # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - # output += f"Trigger Inputs: {block.trigger_inputs}\n" - # # Get first trigger input as example - # example_input = next(t for t in block.trigger_inputs if t is not None) - # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - # output += "Check `.doc` of returned object for more information.\n\n" - - # # List the call parameters - # full_doc = self.pipeline_block.doc - # if "------------------------" in full_doc: - # full_doc = full_doc.split("------------------------")[0].rstrip() - # output += full_doc - - # return output # YiYi TODO: try to unify the to method with the one in DiffusionPipeline # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to