1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add block state will also make sure modifed intermediates_inputs will be updated

This commit is contained in:
yiyixuxu
2025-05-12 01:16:42 +02:00
parent 796453cad1
commit 144eae4e0b

View File

@@ -282,7 +282,7 @@ class ModularPipelineMixin:
state = PipelineState()
if not hasattr(self, "loader"):
logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
self.loader = None
# Make a copy of the input kwargs
@@ -313,7 +313,7 @@ class ModularPipelineMixin:
# Warn about unexpected inputs
if len(passed_kwargs) > 0:
logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
# Run the pipeline
with torch.no_grad():
try:
@@ -373,7 +373,6 @@ class PipelineBlock(ModularPipelineMixin):
return []
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
@property
def inputs(self) -> List[InputParam]:
"""List of input parameters. Must be implemented by subclasses."""
@@ -389,13 +388,16 @@ class PipelineBlock(ModularPipelineMixin):
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
def _get_outputs(self):
return self.intermediates_outputs
# YiYi TODO: is it too easy for user to unintentionally override these properties?
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
@property
def outputs(self) -> List[OutputParam]:
return self.intermediates_outputs
return self._get_outputs()
@property
def required_inputs(self) -> List[str]:
def _get_required_inputs(self):
input_names = []
for input_param in self.inputs:
if input_param.required:
@@ -403,13 +405,23 @@ class PipelineBlock(ModularPipelineMixin):
return input_names
@property
def required_intermediates_inputs(self) -> List[str]:
def required_inputs(self) -> List[str]:
return self._get_required_inputs()
def _get_required_intermediates_inputs(self):
input_names = []
for input_param in self.intermediates_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
return self._get_required_intermediates_inputs()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
raise NotImplementedError("__call__ method must be implemented in subclasses")
@@ -521,6 +533,30 @@ class PipelineBlock(ModularPipelineMixin):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
for input_param in self.intermediates_inputs:
if hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
for input_param in self.intermediates_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediates_kwargs.items():
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(param_name, param, input_param.kwargs_type)
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
input_param.default is not None and
current_param.default != input_param.default):
warnings.warn(
f"Multiple different default values found for input '{input_param.name}': "
f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
f"Multiple different default values found for input '{input_name}': "
f"{current_param.default} (from block '{value_sources[input_name]}') and "
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
)
if current_param.default is None and input_param.default is not None:
combined_dict[input_param.name] = input_param
value_sources[input_param.name] = block_name
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
else:
combined_dict[input_param.name] = input_param
value_sources[input_param.name] = block_name
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
return list(combined_dict.values())
@@ -661,7 +697,9 @@ class AutoPipelineBlocks(ModularPipelineMixin):
required_by_all.intersection_update(block_required)
return list(required_by_all)
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
first_block = next(iter(self.blocks.values()))
@@ -838,14 +876,21 @@ class AutoPipelineBlocks(ModularPipelineMixin):
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
return (
f"{header}\n"
f"{desc}\n\n"
f"{components_str}\n\n"
f"{configs_str}\n\n"
f"{blocks_str}"
f")"
)
# Build the representation with conditional sections
result = f"{header}\n{desc}"
# Only add components section if it has content
if components_str.strip():
result += f"\n\n{components_str}"
# Only add configs section if it has content
if configs_str.strip():
result += f"\n\n{configs_str}"
# Always add blocks section
result += f"\n\n{blocks_str})"
return result
@property
@@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
block_classes = []
block_names = []
@property
def model_name(self):
return next(iter(self.blocks.values())).model_name
@property
def description(self):
return ""
@property
def model_name(self):
return next(iter(self.blocks.values())).model_name
@property
def expected_components(self):
@@ -929,6 +976,8 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
return list(required_by_any)
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
required_intermediates_inputs = []
@@ -960,11 +1009,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
def get_intermediates_inputs(self):
inputs = []
outputs = set()
added_inputs = set()
# Go through all blocks in order
for block in self.blocks.values():
# Add inputs that aren't in outputs yet
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
for inp in block.intermediates_inputs:
if inp.name not in outputs and inp.name not in added_inputs:
inputs.append(inp)
added_inputs.add(inp.name)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
@@ -1176,14 +1229,21 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
return (
f"{header}\n"
f"{desc}\n\n"
f"{components_str}\n\n"
f"{configs_str}\n\n"
f"{blocks_str}"
f")"
)
# Build the representation with conditional sections
result = f"{header}\n{desc}"
# Only add components section if it has content
if components_str.strip():
result += f"\n\n{components_str}"
# Only add configs section if it has content
if configs_str.strip():
result += f"\n\n{configs_str}"
# Always add blocks section
result += f"\n\n{blocks_str})"
return result
@property
@@ -1348,7 +1408,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
return list(required_by_any)
# modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
required_intermediates_inputs = []
@@ -1384,6 +1445,22 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
for block_name, block_cls in zip(self.block_names, self.block_classes):
blocks[block_name] = block_cls()
self.blocks = blocks
@classmethod
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
"""Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
Args:
blocks_dict: Dictionary mapping block names to block instances
Returns:
A new LoopSequentialPipelineBlocks instance
"""
instance = cls()
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
instance.block_names = list(blocks_dict.keys())
instance.blocks = blocks_dict
return instance
def loop_step(self, components, state: PipelineState, **kwargs):
@@ -1455,6 +1532,100 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
param = getattr(block_state, output_param.name)
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
for input_param in self.intermediates_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get_intermediate(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediates_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(param_name, param, input_param.kwargs_type)
@property
def doc(self):
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs
)
# modified from SequentialPipelineBlocks,
#(does not need trigger_inputs related part so removed them,
# do not need to support auto block for loop blocks)
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
header = (
f"{class_name}(\n Class: {base_class}\n"
if base_class and base_class != "object"
else f"{class_name}(\n"
)
# Format description with proper indentation
desc_lines = self.description.split('\n')
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation
if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = '\n'.join(desc) + '\n'
# Components section - focus only on expected components
expected_components = getattr(self, "expected_components", [])
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
# Blocks section - moved to the end with simplified format
blocks_str = " Blocks:\n"
for i, (name, block) in enumerate(self.blocks.items()):
# For SequentialPipelineBlocks, show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split('\n')
indented_desc = desc_lines[0]
if len(desc_lines) > 1:
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
# Build the representation with conditional sections
result = f"{header}\n{desc}"
# Only add components section if it has content
if components_str.strip():
result += f"\n\n{components_str}"
# Only add configs section if it has content
if configs_str.strip():
result += f"\n\n{configs_str}"
# Always add blocks section
result += f"\n\n{blocks_str})"
return result
# YiYi TODO:
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader