mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -45,8 +45,6 @@ from .modular_pipeline_utils import (
|
||||
OutputParam,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_inputs_short,
|
||||
format_intermediates_short,
|
||||
make_doc_string,
|
||||
)
|
||||
|
||||
@@ -142,12 +140,8 @@ class PipelineState:
|
||||
values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
|
||||
kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
|
||||
|
||||
return (
|
||||
f"PipelineState(\n"
|
||||
f" values={{\n{values_str}\n }},\n"
|
||||
f" kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n"
|
||||
f")"
|
||||
)
|
||||
return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockState:
|
||||
@@ -402,20 +396,21 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
current_value = state.get(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
elif input_param.kwargs_type:
|
||||
import ipdb; ipdb.set_trace()
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
try:
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
except:
|
||||
import ipdb; ipdb.set_trace()
|
||||
if param_name is None:
|
||||
continue
|
||||
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@staticmethod
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
@@ -496,200 +491,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
|
||||
class PipelineBlock(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Block is the basic building block of a Modular Pipeline.
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
|
||||
expected_components (List[ComponentSpec], optional):
|
||||
A list of components that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
expected_configs (List[ConfigSpec], optional):
|
||||
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
inputs (List[InputParam], optional):
|
||||
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
intermediate_inputs (List[InputParam], optional):
|
||||
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
intermediate_outputs (List[OutputParam], optional):
|
||||
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
outputs (List[OutputParam], optional):
|
||||
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
required_inputs (List[str], optional):
|
||||
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
|
||||
a property in subclasses.
|
||||
required_intermediate_inputs (List[str], optional):
|
||||
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
required_intermediate_outputs (List[str], optional):
|
||||
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
"""
|
||||
|
||||
model_name = None
|
||||
|
||||
def __init__(self):
|
||||
self.sub_blocks = InsertableDict()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
# raise NotImplementedError("description method must be implemented in subclasses")
|
||||
return "TODO: add a description"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediate_outputs
|
||||
|
||||
# YiYi TODO: is it too easy for user to unintentionally override these properties?
|
||||
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
|
||||
@property
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
def _get_required_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
def _get_required_intermediate_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediate_inputs()
|
||||
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split("\n")
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
# Subsequent lines with proper indentation
|
||||
if len(desc_lines) > 1:
|
||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||
desc = "\n".join(desc) + "\n"
|
||||
|
||||
# Components section - use format_components with add_empty_lines=False
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
|
||||
components = " " + components_str.replace("\n", "\n ")
|
||||
|
||||
# Configs section - use format_configs with add_empty_lines=False
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
configs = " " + configs_str.replace("\n", "\n ")
|
||||
|
||||
# Inputs section
|
||||
inputs_str = format_inputs_short(self.inputs)
|
||||
inputs = "Inputs:\n " + inputs_str
|
||||
|
||||
# Intermediates section
|
||||
intermediates_str = format_intermediates_short(
|
||||
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
|
||||
)
|
||||
intermediates = f"Intermediates:\n{intermediates_str}"
|
||||
|
||||
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the inputs.
|
||||
@@ -1042,7 +843,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
@@ -61,7 +61,7 @@ class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user