1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Dhruv Nair
2025-07-29 21:00:03 +02:00
parent 1db63655e4
commit 496bf0be1b
2 changed files with 14 additions and 213 deletions

View File

@@ -45,8 +45,6 @@ from .modular_pipeline_utils import (
OutputParam,
format_components,
format_configs,
format_inputs_short,
format_intermediates_short,
make_doc_string,
)
@@ -142,12 +140,8 @@ class PipelineState:
values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
return (
f"PipelineState(\n"
f" values={{\n{values_str}\n }},\n"
f" kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n"
f")"
)
return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
@dataclass
class BlockState:
@@ -402,20 +396,21 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
current_value = state.get(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
import ipdb; ipdb.set_trace()
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
try:
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
except:
import ipdb; ipdb.set_trace()
if param_name is None:
continue
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
@staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -496,200 +491,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks):
"""
A Pipeline Block is the basic building block of a Modular Pipeline.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
Args:
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
expected_components (List[ComponentSpec], optional):
A list of components that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
expected_configs (List[ConfigSpec], optional):
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
inputs (List[InputParam], optional):
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
intermediate_inputs (List[InputParam], optional):
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
define as a property in subclasses.
intermediate_outputs (List[OutputParam], optional):
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
define as a property in subclasses.
outputs (List[OutputParam], optional):
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
required_inputs (List[str], optional):
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
a property in subclasses.
required_intermediate_inputs (List[str], optional):
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
override, define as a property in subclasses.
required_intermediate_outputs (List[str], optional):
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
override, define as a property in subclasses.
"""
model_name = None
def __init__(self):
self.sub_blocks = InsertableDict()
@property
def description(self) -> str:
"""Description of the block. Must be implemented by subclasses."""
# raise NotImplementedError("description method must be implemented in subclasses")
return "TODO: add a description"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
"""List of input parameters. Must be implemented by subclasses."""
return []
@property
def intermediate_inputs(self) -> List[InputParam]:
"""List of intermediate input parameters. Must be implemented by subclasses."""
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
def _get_outputs(self):
return self.intermediate_outputs
# YiYi TODO: is it too easy for user to unintentionally override these properties?
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
@property
def outputs(self) -> List[OutputParam]:
return self._get_outputs()
def _get_required_inputs(self):
input_names = []
for input_param in self.inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@property
def required_inputs(self) -> List[str]:
return self._get_required_inputs()
def _get_required_intermediate_inputs(self):
input_names = []
for input_param in self.intermediate_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediate_inputs(self) -> List[str]:
return self._get_required_intermediate_inputs()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
raise NotImplementedError("__call__ method must be implemented in subclasses")
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
# Format description with proper indentation
desc_lines = self.description.split("\n")
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation
if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = "\n".join(desc) + "\n"
# Components section - use format_components with add_empty_lines=False
expected_components = getattr(self, "expected_components", [])
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
components = " " + components_str.replace("\n", "\n ")
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
configs = " " + configs_str.replace("\n", "\n ")
# Inputs section
inputs_str = format_inputs_short(self.inputs)
inputs = "Inputs:\n " + inputs_str
# Intermediates section
intermediates_str = format_intermediates_short(
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
)
intermediates = f"Intermediates:\n{intermediates_str}"
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
@property
def doc(self):
return make_doc_string(
self.inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs,
)
def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set(output_param.name, param, output_param.kwargs_type)
for input_param in self.intermediate_inputs:
if hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(input_param.name, param, input_param.kwargs_type)
for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the inputs.
@@ -1042,7 +843,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
inputs.append(inp)
# Only add outputs if the block cannot be skipped
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
should_add_outputs = False

View File

@@ -61,7 +61,7 @@ class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]
@property