update
@@ -45,8 +45,6 @@ from .modular_pipeline_utils import (
    OutputParam,
    format_components,
    format_configs,
    format_inputs_short,
    format_intermediates_short,
    make_doc_string,
)

@@ -76,139 +74,59 @@ class PipelineState:
    [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
    """

    inputs: Dict[str, Any] = field(default_factory=dict)
    intermediates: Dict[str, Any] = field(default_factory=dict)
    input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
    intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
    values: Dict[str, Any] = field(default_factory=dict)
    kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)

    def set_input(self, key: str, value: Any, kwargs_type: str = None):
    def set(self, key: str, value: Any, kwargs_type: str = None):
        """
        Add an input to the immutable pipeline state, i.e, pipeline_state.inputs.

        The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
        set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
        pipeline block has "guider_kwargs" in its expected_inputs list.
        Add a value to the pipeline state.

        Args:
            key (str): The key for the input
            value (Any): The input value
            kwargs_type (str): The kwargs_type with which the input is associated
            key (str): The key for the value
            value (Any): The value to store
            kwargs_type (str): The kwargs_type with which the value is associated
        """
        self.inputs[key] = value
        self.values[key] = value

        if kwargs_type is not None:
            if kwargs_type not in self.input_kwargs:
                self.input_kwargs[kwargs_type] = [key]
            if kwargs_type not in self.kwargs_mapping:
                self.kwargs_mapping[kwargs_type] = [key]
            else:
                self.input_kwargs[kwargs_type].append(key)
                self.kwargs_mapping[kwargs_type].append(key)

    def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
    def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]:
        """
        Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates.

        The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
        if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
        automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
        Get one or multiple values from the pipeline state.

        Args:
            key (str): The key for the intermediate value
            value (Any): The intermediate value
            kwargs_type (str): The kwargs_type with which the intermediate value is associated
        """
        self.intermediates[key] = value
        if kwargs_type is not None:
            if kwargs_type not in self.intermediate_kwargs:
                self.intermediate_kwargs[kwargs_type] = [key]
            else:
                self.intermediate_kwargs[kwargs_type].append(key)

    def get_input(self, key: str, default: Any = None) -> Any:
        """
        Get an input from the pipeline state.

        Args:
            key (str): The key for the input
            default (Any): The default value to return if the input is not found
            keys (Union[str, List[str]]): Key or list of keys for the values
            default (Any): The default value to return if not found

        Returns:
            Any: The input value
            Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list
        """
        value = self.inputs.get(key, default)
        if value is not None:
            return deepcopy(value)
        if isinstance(keys, str):
            return self.values.get(keys, default)
        return {key: self.values.get(key, default) for key in keys}

    def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
    def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
        """
        Get multiple inputs from the pipeline state.

        Args:
            keys (List[str]): The keys for the inputs
            default (Any): The default value to return if the input is not found

        Returns:
            Dict[str, Any]: Dictionary of inputs with matching keys
        """
        return {key: self.inputs.get(key, default) for key in keys}

    def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
        """
        Get all inputs with matching kwargs_type.
        Get all values with matching kwargs_type.

        Args:
            kwargs_type (str): The kwargs_type to filter by

        Returns:
            Dict[str, Any]: Dictionary of inputs with matching kwargs_type
            Dict[str, Any]: Dictionary of values with matching kwargs_type
        """
        input_names = self.input_kwargs.get(kwargs_type, [])
        return self.get_inputs(input_names)

    def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
        """
        Get all intermediates with matching kwargs_type.

        Args:
            kwargs_type (str): The kwargs_type to filter by

        Returns:
            Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
        """
        intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
        return self.get_intermediates(intermediate_names)

    def get_intermediate(self, key: str, default: Any = None) -> Any:
        """
        Get an intermediate value from the pipeline state.

        Args:
            key (str): The key for the intermediate value
            default (Any): The default value to return if the intermediate value is not found

        Returns:
            Any: The intermediate value
        """
        return self.intermediates.get(key, default)

    def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
        """
        Get multiple intermediate values from the pipeline state.

        Args:
            keys (List[str]): The keys for the intermediate values
            default (Any): The default value to return if the intermediate value is not found

        Returns:
            Dict[str, Any]: Dictionary of intermediate values with matching keys
        """
        return {key: self.intermediates.get(key, default) for key in keys}
        value_names = self.kwargs_mapping.get(kwargs_type, [])
        return self.get(value_names)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert PipelineState to a dictionary.

        Returns:
            Dict[str, Any]: Dictionary containing all attributes of the PipelineState
        """
        return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates}
        return {**self.__dict__}

    def __repr__(self):
        def format_value(v):
@@ -219,21 +137,10 @@ class PipelineState:
            else:
                return repr(v)

        inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
        intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
        values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
        kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())

        # Format input_kwargs and intermediate_kwargs
        input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
        intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())

        return (
            f"PipelineState(\n"
            f" inputs={{\n{inputs}\n }},\n"
            f" intermediates={{\n{intermediates}\n }},\n"
            f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
            f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
            f")"
        )
        return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"


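The hunk above collapses the separate input/intermediate stores into a single `values` dict behind one `set`/`get`/`get_by_kwargs` surface. The following is a minimal usage sketch only, not part of this commit; the import path and the `guider_kwargs` tag are assumptions used for illustration.

```python
# Illustrative sketch, not part of the diff: exercising the unified PipelineState API.
import torch

from diffusers.modular_pipelines import PipelineState  # import path assumed

state = PipelineState()

# A single `set` replaces `set_input`/`set_intermediate`; `kwargs_type`
# tags the value so it can later be fetched as part of a group.
state.set("prompt_embeds", torch.zeros(1, 77, 2048), kwargs_type="guider_kwargs")
state.set("negative_prompt_embeds", torch.zeros(1, 77, 2048), kwargs_type="guider_kwargs")
state.set("num_inference_steps", 30)

print(state.get("num_inference_steps"))                        # single key -> value
print(sorted(state.get(["prompt_embeds", "num_inference_steps"]).keys()))  # list of keys -> dict
print(sorted(state.get_by_kwargs("guider_kwargs").keys()))     # all values tagged "guider_kwargs"
```
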
@dataclass
@@ -322,7 +229,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    </Tip>
    """

    config_name = "config.json"
    config_name = "modular_config.json"
    model_name = None

    @classmethod
@@ -334,6 +241,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

        return expected_modules, optional_parameters

    def __init__(self):
        self.sub_blocks = InsertableDict()

    @property
    def description(self) -> str:
        """Description of the block. Must be implemented by subclasses."""
        return ""

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []
@@ -343,8 +258,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
        return []

    @property
    def intermediate_inputs(self) -> List[OutputParam]:
        """List of intermediate output parameters. Must be implemented by subclasses."""
    def inputs(self) -> List[InputParam]:
        """List of input parameters. Must be implemented by subclasses."""
        return []

    @property
@@ -352,6 +267,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
        """List of intermediate output parameters. Must be implemented by subclasses."""
        return []

    def _get_outputs(self):
        return self.intermediate_outputs

    @property
    def outputs(self) -> List[OutputParam]:
        return self._get_outputs()

    @classmethod
    def from_pretrained(
        cls,
@@ -436,12 +358,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    def get_block_state(self, state: PipelineState) -> dict:
        """Get all inputs and intermediates in one dictionary"""
        data = {}
        state_inputs = self.inputs + self.intermediate_inputs
        state_inputs = self.inputs

        # Check inputs
        for input_param in state_inputs:
            if input_param.name:
                value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
                value = state.get(input_param.name)
                if input_param.required and value is None:
                    raise ValueError(f"Required input '{input_param.name}' is missing")
                elif value is not None or (value is None and input_param.name not in data):
@@ -451,9 +373,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
                # if kwargs_type is provided, get all inputs with matching kwargs_type
                if input_param.kwargs_type not in data:
                    data[input_param.kwargs_type] = {}
                inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
                    input_param.kwargs_type
                )
                inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type)
                if inputs_kwargs:
                    for k, v in inputs_kwargs.items():
                        if v is not None:
@@ -467,25 +387,30 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
            if not hasattr(block_state, output_param.name):
                raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
            param = getattr(block_state, output_param.name)
            state.set_intermediate(output_param.name, param, output_param.kwargs_type)
            state.set(output_param.name, param, output_param.kwargs_type)

        for input_param in self.intermediate_inputs:
        for input_param in self.inputs:
            if input_param.name and hasattr(block_state, input_param.name):
                param = getattr(block_state, input_param.name)
                # Only add if the value is different from what's in the state
                current_value = state.get_intermediate(input_param.name)
                current_value = state.get(input_param.name)
                if current_value is not param:  # Using identity comparison to check if object was modified
                    state.set_intermediate(input_param.name, param, input_param.kwargs_type)
                    state.set(input_param.name, param, input_param.kwargs_type)

            elif input_param.kwargs_type:
                # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
                # we need to first find out which inputs are and loop through them.
                intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
                intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
                for param_name, current_value in intermediate_kwargs.items():
                    if param_name is None:
                        continue

                    if not hasattr(block_state, param_name):
                        continue

                    param = getattr(block_state, param_name)
                    if current_value is not param:  # Using identity comparison to check if object was modified
                        state.set_intermediate(param_name, param, input_param.kwargs_type)
                        state.set(param_name, param, input_param.kwargs_type)

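With `intermediate_inputs` folded into `inputs` and the state setters unified, a block declares one list of inputs and reads/writes them through `get_block_state`/`set_block_state`. Below is a hedged sketch, not part of this commit, of what a minimal custom block could look like under the refactored base class; the block name, import paths, and the `(components, state)` call convention are assumptions modeled on the blocks changed elsewhere in this diff.

```python
# Illustrative sketch, not part of the diff: a hypothetical custom block.
import torch

from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState  # paths assumed
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam


class ScaleLatentsStep(ModularPipelineBlocks):  # invented example block
    model_name = "example"

    @property
    def inputs(self):
        # One unified list; no separate intermediate_inputs property after the refactor.
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
            InputParam("scale", default=1.0),
        ]

    @property
    def intermediate_outputs(self):
        return [OutputParam("latents", type_hint=torch.Tensor)]

    def __call__(self, components, state: PipelineState):
        # get_block_state collects the declared inputs; attribute-style access
        # mirrors how the blocks in this diff use block_state.
        block_state = self.get_block_state(state)
        block_state.latents = block_state.latents * block_state.scale
        # set_block_state writes modified values back via the unified state.set(...)
        self.set_block_state(state, block_state)
        return components, state
```
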
    @staticmethod
    def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -553,199 +478,17 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

        return list(combined_dict.values())


class PipelineBlock(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Block is the basic building block of a Modular Pipeline.
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
|
||||
expected_components (List[ComponentSpec], optional):
|
||||
A list of components that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
expected_configs (List[ConfigSpec], optional):
|
||||
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
inputs (List[InputParam], optional):
|
||||
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
intermediate_inputs (List[InputParam], optional):
|
||||
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
intermediate_outputs (List[OutputParam], optional):
|
||||
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
outputs (List[OutputParam], optional):
|
||||
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
required_inputs (List[str], optional):
|
||||
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
|
||||
a property in subclasses.
|
||||
required_intermediate_inputs (List[str], optional):
|
||||
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
required_intermediate_outputs (List[str], optional):
|
||||
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
"""
|
||||
|
||||
model_name = None
|
||||
|
||||
def __init__(self):
|
||||
self.sub_blocks = InsertableDict()
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
# raise NotImplementedError("description method must be implemented in subclasses")
|
||||
return "TODO: add a description"
|
||||
def intermediate_output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.intermediate_outputs]
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediate_outputs
|
||||
|
||||
# YiYi TODO: is it too easy for user to unintentionally override these properties?
|
||||
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
|
||||
@property
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
def _get_required_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
def _get_required_intermediate_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediate_inputs()
|
||||
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split("\n")
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
# Subsequent lines with proper indentation
|
||||
if len(desc_lines) > 1:
|
||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||
desc = "\n".join(desc) + "\n"
|
||||
|
||||
# Components section - use format_components with add_empty_lines=False
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
|
||||
components = " " + components_str.replace("\n", "\n ")
|
||||
|
||||
# Configs section - use format_configs with add_empty_lines=False
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
configs = " " + configs_str.replace("\n", "\n ")
|
||||
|
||||
# Inputs section
|
||||
inputs_str = format_inputs_short(self.inputs)
|
||||
inputs = "Inputs:\n " + inputs_str
|
||||
|
||||
# Intermediates section
|
||||
intermediates_str = format_intermediates_short(
|
||||
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
|
||||
)
|
||||
intermediates = f"Intermediates:\n{intermediates_str}"
|
||||
|
||||
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
@@ -836,22 +579,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
first_block = next(iter(self.sub_blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
|
||||
|
||||
# Intersect with required inputs from all other blocks
|
||||
for block in list(self.sub_blocks.values())[1:]:
|
||||
block_required = set(getattr(block, "required_intermediate_inputs", set()))
|
||||
required_by_all.intersection_update(block_required)
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -865,18 +592,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required by all the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_intermediate_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
@@ -895,10 +610,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
block = self.trigger_to_block_map.get(None)
|
||||
for input_name in self.block_trigger_inputs:
|
||||
if input_name is not None and state.get_input(input_name) is not None:
|
||||
if input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
elif input_name is not None and state.get_intermediate(input_name) is not None:
|
||||
elif input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
|
||||
@@ -1117,6 +832,34 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
            sub_blocks[block_name] = block_cls()
        self.sub_blocks = sub_blocks

    def _get_inputs(self):
        inputs = []
        outputs = set()

        # Go through all blocks in order
        for block in self.sub_blocks.values():
            # Add inputs that aren't in outputs yet
            for inp in block.inputs:
                if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
                    inputs.append(inp)

            # Only add outputs if the block cannot be skipped
            should_add_outputs = True
            if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
                should_add_outputs = False

            if should_add_outputs:
                # Add this block's outputs
                block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
                outputs.update(block_intermediate_outputs)

        return inputs

    # YiYi TODO: add test for this
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return self._get_inputs()

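The `_get_inputs` helper added above walks the sub-blocks in order and only exposes an input at the pipeline level if no earlier (non-skippable) block has already produced it as an intermediate output. A small standalone sketch of that rule on made-up block data follows; it is illustrative only, omits the trigger-input (skippable block) branch, and all block and parameter names are invented.

```python
# Illustrative sketch, not part of the diff: the input-resolution rule of _get_inputs
# replayed on plain dictionaries.
blocks = {
    "text_encoder": {"inputs": ["prompt"], "intermediate_outputs": ["prompt_embeds"]},
    "denoise": {"inputs": ["prompt_embeds", "num_inference_steps"], "intermediate_outputs": ["latents"]},
    "decode": {"inputs": ["latents", "output_type"], "intermediate_outputs": ["images"]},
}

exposed, produced = [], set()
for block in blocks.values():
    for name in block["inputs"]:
        # only inputs no earlier block has produced become pipeline-level inputs
        if name not in produced and name not in exposed:
            exposed.append(name)
    produced.update(block["intermediate_outputs"])

print(exposed)  # ['prompt', 'num_inference_steps', 'output_type']
```
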
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
# Get the first block from the dictionary
|
||||
@@ -1130,65 +873,11 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return self.get_inputs()
|
||||
|
||||
def get_inputs(self):
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required any of the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return self.get_intermediate_inputs()
|
||||
|
||||
def get_intermediate_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
added_inputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.sub_blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
for inp in block.intermediate_inputs:
|
||||
if inp.name not in outputs and inp.name not in added_inputs:
|
||||
inputs.append(inp)
|
||||
added_inputs.add(inp.name)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = []
|
||||
for name, block in self.sub_blocks.items():
|
||||
inp_names = {inp.name for inp in block.intermediate_inputs}
|
||||
inp_names = {inp.name for inp in block.inputs}
|
||||
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
|
||||
# filter out them here so they do not end up as intermediate_outputs
|
||||
if name not in inp_names:
|
||||
@@ -1406,7 +1095,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -1456,16 +1144,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_required_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
@@ -1475,12 +1153,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_expected_components
|
||||
@property
|
||||
@@ -1508,43 +1183,16 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs.append(config)
|
||||
return expected_configs
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_inputs
|
||||
def get_inputs(self):
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
named_inputs.append(("loop", self.loop_inputs))
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required any of the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
|
||||
def inputs(self):
|
||||
return self.get_inputs()
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_inputs
|
||||
@property
|
||||
def intermediate_inputs(self):
|
||||
intermediates = self.get_intermediate_inputs()
|
||||
intermediate_names = [input.name for input in intermediates]
|
||||
for loop_intermediate_input in self.loop_intermediate_inputs:
|
||||
if loop_intermediate_input.name not in intermediate_names:
|
||||
intermediates.append(loop_intermediate_input)
|
||||
return intermediates
|
||||
|
||||
# modified from SequentialPipelineBlocks
|
||||
def get_intermediate_inputs(self):
|
||||
def _get_inputs(self):
|
||||
inputs = []
|
||||
inputs.extend(self.loop_inputs)
|
||||
outputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.sub_blocks.values():
|
||||
for name, block in self.sub_blocks.items():
|
||||
# Add inputs that aren't in outputs yet
|
||||
inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
|
||||
for inp in block.inputs:
|
||||
if inp.name not in outputs and inp not in inputs:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
@@ -1555,8 +1203,20 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
for input_param in inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
|
||||
def inputs(self):
|
||||
return self._get_inputs()
|
||||
|
||||
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
@@ -1574,19 +1234,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: this need to be thought about more
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs
|
||||
@property
|
||||
@@ -1876,96 +1523,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Execute the pipeline by running the pipeline blocks with the given inputs.
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
|
||||
created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
|
||||
"latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Get complete pipeline state
|
||||
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
|
||||
print(state.intermediates) # All intermediate outputs
|
||||
|
||||
# Get specific output
|
||||
image = pipeline(prompt="A beautiful sunset", output="image")
|
||||
|
||||
# Get multiple specific outputs
|
||||
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
|
||||
image, latents = results["image"], results["latents"]
|
||||
|
||||
# Continue from previous state
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
|
||||
`output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediate_inputs:
|
||||
state.set_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.set_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.set_input(name, default, kwargs_type)
|
||||
|
||||
for expected_intermediate_param in self.blocks.intermediate_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
elif isinstance(output, str):
|
||||
return state.get_intermediate(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get_intermediates(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
"""
|
||||
Load from_pretrained components using the loading specs in the config dict.
|
||||
@@ -2784,3 +2341,92 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
type_hint=type_hint,
|
||||
**spec_dict,
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Execute the pipeline by running the pipeline blocks with the given inputs.
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
|
||||
created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
|
||||
"latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Get complete pipeline state
|
||||
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
|
||||
print(state.intermediates) # All intermediate outputs
|
||||
|
||||
# Get specific output
|
||||
image = pipeline(prompt="A beautiful sunset", output="image")
|
||||
|
||||
# Get multiple specific outputs
|
||||
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
|
||||
image, latents = results["image"], results["latents"]
|
||||
|
||||
# Continue from previous state
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
|
||||
`output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediate_inputs:
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.set(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.values:
|
||||
state.set(name, default, kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
if isinstance(output, str):
|
||||
return state.get(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -691,7 +686,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +721,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||