mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add block state will also make sure modifed intermediates_inputs will be updated
This commit is contained in:
@@ -282,7 +282,7 @@ class ModularPipelineMixin:
|
||||
state = PipelineState()
|
||||
|
||||
if not hasattr(self, "loader"):
|
||||
logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
|
||||
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
|
||||
self.loader = None
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
@@ -313,7 +313,7 @@ class ModularPipelineMixin:
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
@@ -373,7 +373,6 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
return []
|
||||
|
||||
|
||||
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
@@ -389,13 +388,16 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediates_outputs
|
||||
|
||||
# YiYi TODO: is it too easy for user to unintentionally override these properties?
|
||||
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
|
||||
@property
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self.intermediates_outputs
|
||||
return self._get_outputs()
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
def _get_required_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.inputs:
|
||||
if input_param.required:
|
||||
@@ -403,13 +405,23 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
def required_inputs(self) -> List[str]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
|
||||
def _get_required_intermediates_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.intermediates_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediates_inputs()
|
||||
|
||||
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||
@@ -521,6 +533,30 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
if hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediates_kwargs.items():
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(param_name, param, input_param.kwargs_type)
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
@@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
|
||||
input_param.default is not None and
|
||||
current_param.default != input_param.default):
|
||||
warnings.warn(
|
||||
f"Multiple different default values found for input '{input_param.name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
|
||||
f"Multiple different default values found for input '{input_name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_name]}') and "
|
||||
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
|
||||
)
|
||||
if current_param.default is None and input_param.default is not None:
|
||||
combined_dict[input_param.name] = input_param
|
||||
value_sources[input_param.name] = block_name
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
else:
|
||||
combined_dict[input_param.name] = input_param
|
||||
value_sources[input_param.name] = block_name
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@@ -661,7 +697,9 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
required_by_all.intersection_update(block_required)
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
first_block = next(iter(self.blocks.values()))
|
||||
@@ -838,14 +876,21 @@ class AutoPipelineBlocks(ModularPipelineMixin):
|
||||
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
return (
|
||||
f"{header}\n"
|
||||
f"{desc}\n\n"
|
||||
f"{components_str}\n\n"
|
||||
f"{configs_str}\n\n"
|
||||
f"{blocks_str}"
|
||||
f")"
|
||||
)
|
||||
# Build the representation with conditional sections
|
||||
result = f"{header}\n{desc}"
|
||||
|
||||
# Only add components section if it has content
|
||||
if components_str.strip():
|
||||
result += f"\n\n{components_str}"
|
||||
|
||||
# Only add configs section if it has content
|
||||
if configs_str.strip():
|
||||
result += f"\n\n{configs_str}"
|
||||
|
||||
# Always add blocks section
|
||||
result += f"\n\n{blocks_str})"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@property
|
||||
@@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
block_classes = []
|
||||
block_names = []
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
return next(iter(self.blocks.values())).model_name
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return ""
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
return next(iter(self.blocks.values())).model_name
|
||||
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
@@ -929,6 +976,8 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
required_intermediates_inputs = []
|
||||
@@ -960,11 +1009,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
def get_intermediates_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
added_inputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
|
||||
for inp in block.intermediates_inputs:
|
||||
if inp.name not in outputs and inp.name not in added_inputs:
|
||||
inputs.append(inp)
|
||||
added_inputs.add(inp.name)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
@@ -1176,14 +1229,21 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
return (
|
||||
f"{header}\n"
|
||||
f"{desc}\n\n"
|
||||
f"{components_str}\n\n"
|
||||
f"{configs_str}\n\n"
|
||||
f"{blocks_str}"
|
||||
f")"
|
||||
)
|
||||
# Build the representation with conditional sections
|
||||
result = f"{header}\n{desc}"
|
||||
|
||||
# Only add components section if it has content
|
||||
if components_str.strip():
|
||||
result += f"\n\n{components_str}"
|
||||
|
||||
# Only add configs section if it has content
|
||||
if configs_str.strip():
|
||||
result += f"\n\n{configs_str}"
|
||||
|
||||
# Always add blocks section
|
||||
result += f"\n\n{blocks_str})"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@property
|
||||
@@ -1348,7 +1408,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
required_intermediates_inputs = []
|
||||
@@ -1384,6 +1445,22 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
|
||||
for block_name, block_cls in zip(self.block_names, self.block_classes):
|
||||
blocks[block_name] = block_cls()
|
||||
self.blocks = blocks
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
|
||||
"""Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
|
||||
|
||||
Args:
|
||||
blocks_dict: Dictionary mapping block names to block instances
|
||||
|
||||
Returns:
|
||||
A new LoopSequentialPipelineBlocks instance
|
||||
"""
|
||||
instance = cls()
|
||||
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
|
||||
instance.block_names = list(blocks_dict.keys())
|
||||
instance.blocks = blocks_dict
|
||||
return instance
|
||||
|
||||
def loop_step(self, components, state: PipelineState, **kwargs):
|
||||
|
||||
@@ -1455,6 +1532,100 @@ class LoopSequentialPipelineBlocks(ModularPipelineMixin):
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediates_kwargs.items():
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(param_name, param, input_param.kwargs_type)
|
||||
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediates_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs
|
||||
)
|
||||
|
||||
# modified from SequentialPipelineBlocks,
|
||||
#(does not need trigger_inputs related part so removed them,
|
||||
# do not need to support auto block for loop blocks)
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
header = (
|
||||
f"{class_name}(\n Class: {base_class}\n"
|
||||
if base_class and base_class != "object"
|
||||
else f"{class_name}(\n"
|
||||
)
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split('\n')
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
# Subsequent lines with proper indentation
|
||||
if len(desc_lines) > 1:
|
||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||
desc = '\n'.join(desc) + '\n'
|
||||
|
||||
# Components section - focus only on expected components
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
|
||||
|
||||
# Configs section - use format_configs with add_empty_lines=False
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
|
||||
# Blocks section - moved to the end with simplified format
|
||||
blocks_str = " Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.blocks.items()):
|
||||
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split('\n')
|
||||
indented_desc = desc_lines[0]
|
||||
if len(desc_lines) > 1:
|
||||
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
# Build the representation with conditional sections
|
||||
result = f"{header}\n{desc}"
|
||||
|
||||
# Only add components section if it has content
|
||||
if components_str.strip():
|
||||
result += f"\n\n{components_str}"
|
||||
|
||||
# Only add configs section if it has content
|
||||
if configs_str.strip():
|
||||
result += f"\n\n{configs_str}"
|
||||
|
||||
# Always add blocks section
|
||||
result += f"\n\n{blocks_str})"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
# YiYi TODO:
|
||||
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
||||
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader
|
||||
|
||||
Reference in New Issue
Block a user