mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
intermediates_inputs -> intermediate_inputs; component_manager -> components_manager, and more
This commit is contained in:
@@ -284,12 +284,12 @@ class ComponentsManager:
|
||||
if comp == component:
|
||||
comp_name = self._id_to_name(comp_id)
|
||||
if comp_name == name:
|
||||
logger.warning(f"component '{name}' already exists as '{comp_id}'")
|
||||
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
|
||||
component_id = comp_id
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
|
||||
f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
|
||||
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
|
||||
)
|
||||
|
||||
@@ -301,7 +301,7 @@ class ComponentsManager:
|
||||
if components_with_same_load_id:
|
||||
existing = ", ".join(components_with_same_load_id)
|
||||
logger.warning(
|
||||
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||
f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
|
||||
)
|
||||
|
||||
@@ -315,12 +315,12 @@ class ComponentsManager:
|
||||
if component_id not in self.collections[collection]:
|
||||
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
|
||||
for comp_id in comp_ids_in_collection:
|
||||
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
|
||||
logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}")
|
||||
self.remove(comp_id)
|
||||
self.collections[collection].add(component_id)
|
||||
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
|
||||
logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}")
|
||||
else:
|
||||
logger.info(f"Added component '{name}' as '{component_id}'")
|
||||
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
@@ -659,6 +659,10 @@ class ComponentsManager:
|
||||
return info
|
||||
|
||||
def __repr__(self):
|
||||
# Handle empty components case
|
||||
if not self.components:
|
||||
return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50
|
||||
|
||||
# Helper to get simple name without UUID
|
||||
def get_simple_name(name):
|
||||
# Extract the base name by splitting on underscore and taking first part
|
||||
@@ -802,51 +806,6 @@ class ComponentsManager:
|
||||
|
||||
return output
|
||||
|
||||
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Load components from a pretrained model and add them to the manager.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
|
||||
prefix (str, optional): Prefix to add to all component names loaded from this model.
|
||||
If provided, components will be named as "{prefix}_{component_name}"
|
||||
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
|
||||
"""
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
# YiYi TODO: extend AutoModel to support non-diffusers models
|
||||
if subfolder:
|
||||
from ..models import AutoModel
|
||||
|
||||
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
|
||||
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
|
||||
if component_name not in self.components:
|
||||
self.add(component_name, component)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
|
||||
f"1. remove the existing component with remove('{component_name}')\n"
|
||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
for name, component in pipe.components.items():
|
||||
if component is None:
|
||||
continue
|
||||
|
||||
# Add prefix if specified
|
||||
component_name = f"{prefix}_{name}" if prefix else name
|
||||
|
||||
if component_name not in self.components:
|
||||
self.add(component_name, component)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
|
||||
f"1. remove the existing component with remove('{component_name}')\n"
|
||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||
)
|
||||
|
||||
def get_one(
|
||||
self,
|
||||
component_id: Optional[str] = None,
|
||||
|
||||
@@ -126,7 +126,7 @@ class PipelineState:
|
||||
input_names = self.input_kwargs.get(kwargs_type, [])
|
||||
return self.get_inputs(input_names)
|
||||
|
||||
def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all intermediates with matching kwargs_type.
|
||||
|
||||
@@ -325,7 +325,7 @@ class ModularPipelineBlocks(ConfigMixin):
|
||||
def init_pipeline(
|
||||
self,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
component_manager: Optional[ComponentsManager] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -344,10 +344,10 @@ class ModularPipelineBlocks(ConfigMixin):
|
||||
loader = loader_class(
|
||||
specs=specs,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
component_manager=component_manager,
|
||||
components_manager=components_manager,
|
||||
collection=collection,
|
||||
)
|
||||
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
|
||||
modular_pipeline = ModularPipeline(blocks=deepcopy(self), loader=loader)
|
||||
return modular_pipeline
|
||||
|
||||
|
||||
@@ -374,17 +374,17 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediates_outputs
|
||||
return self.intermediate_outputs
|
||||
|
||||
# YiYi TODO: is it too easy for user to unintentionally override these properties?
|
||||
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
|
||||
@@ -403,9 +403,9 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
def required_inputs(self) -> List[str]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
def _get_required_intermediates_inputs(self):
|
||||
def _get_required_intermediate_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
@@ -413,8 +413,8 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediates_inputs()
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediate_inputs()
|
||||
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||
@@ -449,7 +449,7 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
|
||||
# Intermediates section
|
||||
intermediates_str = format_intermediates_short(
|
||||
self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs
|
||||
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
|
||||
)
|
||||
intermediates = f"Intermediates:\n{intermediates_str}"
|
||||
|
||||
@@ -459,7 +459,7 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediates_inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -492,7 +492,7 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
# Check intermediates
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
@@ -503,9 +503,9 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
if intermediates_kwargs:
|
||||
for k, v in intermediates_kwargs.items():
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
if intermediate_kwargs:
|
||||
for k, v in intermediate_kwargs.items():
|
||||
if v is not None:
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
@@ -513,13 +513,13 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
return BlockState(**data)
|
||||
|
||||
def add_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediates_outputs:
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
@@ -527,7 +527,7 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
@@ -537,8 +537,8 @@ class PipelineBlock(ModularPipelineBlocks):
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediates_kwargs.items():
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.add_intermediate(param_name, param, input_param.kwargs_type)
|
||||
@@ -610,6 +610,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
# YiYi TODO: change blocks attribute to a different name, so it is not confused with the blocks attribute in ModularPipeline
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A class that automatically selects a block to run based on the inputs.
|
||||
@@ -692,15 +693,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
first_block = next(iter(self.blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_intermediates_inputs", set()))
|
||||
required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
|
||||
|
||||
# Intersect with required inputs from all other blocks
|
||||
for block in list(self.blocks.values())[1:]:
|
||||
block_required = set(getattr(block, "required_intermediates_inputs", set()))
|
||||
block_required = set(getattr(block, "required_intermediate_inputs", set()))
|
||||
required_by_all.intersection_update(block_required)
|
||||
|
||||
return list(required_by_all)
|
||||
@@ -719,20 +720,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
named_inputs = [(name, block.intermediate_inputs) for name, block in self.blocks.items()]
|
||||
combined_inputs = combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required by all the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_intermediates_inputs:
|
||||
if input_param.name in self.required_intermediate_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()]
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@@ -885,7 +886,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediates_inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -975,12 +976,12 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
required_intermediates_inputs = []
|
||||
for input_param in self.intermediates_inputs:
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediates_inputs.append(input_param.name)
|
||||
return required_intermediates_inputs
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
@@ -999,10 +1000,10 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
return self.get_intermediates_inputs()
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return self.get_intermediate_inputs()
|
||||
|
||||
def get_intermediates_inputs(self):
|
||||
def get_intermediate_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
added_inputs = set()
|
||||
@@ -1010,7 +1011,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Go through all blocks in order
|
||||
for block in self.blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
for inp in block.intermediates_inputs:
|
||||
for inp in block.intermediate_inputs:
|
||||
if inp.name not in outputs and inp.name not in added_inputs:
|
||||
inputs.append(inp)
|
||||
added_inputs.add(inp.name)
|
||||
@@ -1022,27 +1023,27 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
|
||||
outputs.update(block_intermediates_outputs)
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = []
|
||||
for name, block in self.blocks.items():
|
||||
inp_names = {inp.name for inp in block.intermediates_inputs}
|
||||
# so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
|
||||
# filter out them here so they do not end up as intermediates_outputs
|
||||
inp_names = {inp.name for inp in block.intermediate_inputs}
|
||||
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
|
||||
# filter out them here so they do not end up as intermediate_outputs
|
||||
if name not in inp_names:
|
||||
named_outputs.append((name, block.intermediates_outputs))
|
||||
named_outputs.append((name, block.intermediate_outputs))
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
# YiYi TODO: I think we can remove the outputs property
|
||||
@property
|
||||
def outputs(self) -> List[str]:
|
||||
# return next(reversed(self.blocks.values())).intermediates_outputs
|
||||
return self.intermediates_outputs
|
||||
# return next(reversed(self.blocks.values())).intermediate_outputs
|
||||
return self.intermediate_outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -1248,7 +1249,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediates_inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -1287,12 +1288,12 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediates_inputs(self) -> List[InputParam]:
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediates_outputs(self) -> List[OutputParam]:
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@@ -1305,9 +1306,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def loop_required_intermediates_inputs(self) -> List[str]:
|
||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
for input_param in self.loop_intermediates_inputs:
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
@@ -1356,25 +1357,25 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def inputs(self):
|
||||
return self.get_inputs()
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediates_inputs
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_inputs
|
||||
@property
|
||||
def intermediates_inputs(self):
|
||||
intermediates = self.get_intermediates_inputs()
|
||||
def intermediate_inputs(self):
|
||||
intermediates = self.get_intermediate_inputs()
|
||||
intermediate_names = [input.name for input in intermediates]
|
||||
for loop_intermediate_input in self.loop_intermediates_inputs:
|
||||
for loop_intermediate_input in self.loop_intermediate_inputs:
|
||||
if loop_intermediate_input.name not in intermediate_names:
|
||||
intermediates.append(loop_intermediate_input)
|
||||
return intermediates
|
||||
|
||||
# modified from SequentialPipelineBlocks
|
||||
def get_intermediates_inputs(self):
|
||||
def get_intermediate_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
|
||||
inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
@@ -1383,8 +1384,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
|
||||
outputs.update(block_intermediates_outputs)
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
return inputs
|
||||
|
||||
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
|
||||
@@ -1407,23 +1408,23 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediates_inputs(self) -> List[str]:
|
||||
required_intermediates_inputs = []
|
||||
for input_param in self.intermediates_inputs:
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediates_inputs.append(input_param.name)
|
||||
for input_param in self.loop_intermediates_inputs:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediates_inputs.append(input_param.name)
|
||||
return required_intermediates_inputs
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: this need to be thought about more
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediates_outputs
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()]
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
for output in self.loop_intermediates_outputs:
|
||||
for output in self.loop_intermediate_outputs:
|
||||
if output.name not in {output.name for output in combined_outputs}:
|
||||
combined_outputs.append(output)
|
||||
return combined_outputs
|
||||
@@ -1431,7 +1432,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO: this need to be thought about more
|
||||
@property
|
||||
def outputs(self) -> List[str]:
|
||||
return next(reversed(self.blocks.values())).intermediates_outputs
|
||||
return next(reversed(self.blocks.values())).intermediate_outputs
|
||||
|
||||
def __init__(self):
|
||||
blocks = InsertableOrderedDict()
|
||||
@@ -1497,7 +1498,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
data[input_param.kwargs_type][k] = v
|
||||
|
||||
# Check intermediates
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_intermediate(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
@@ -1508,9 +1509,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# if kwargs_type is provided, get all intermediates with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
if intermediates_kwargs:
|
||||
for k, v in intermediates_kwargs.items():
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
if intermediate_kwargs:
|
||||
for k, v in intermediate_kwargs.items():
|
||||
if v is not None:
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
@@ -1518,13 +1519,13 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return BlockState(**data)
|
||||
|
||||
def add_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediates_outputs:
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediates_inputs:
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
@@ -1534,8 +1535,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediates_kwargs.items():
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
param = getattr(block_state, param_name)
|
||||
@@ -1546,7 +1547,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediates_inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -1660,7 +1661,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or
|
||||
loader.update(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"])
|
||||
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"])
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
@@ -1710,8 +1711,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
if not is_registered:
|
||||
self.register_to_config(**register_dict)
|
||||
setattr(self, name, module)
|
||||
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
|
||||
self._component_manager.add(name, module, self._collection)
|
||||
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
|
||||
self._components_manager.add(name, module, self._collection)
|
||||
continue
|
||||
|
||||
current_module = getattr(self, name, None)
|
||||
@@ -1745,22 +1746,22 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
# finally set models
|
||||
setattr(self, name, module)
|
||||
# add to component manager if one is attached
|
||||
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
|
||||
self._component_manager.add(name, module, self._collection)
|
||||
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
|
||||
self._components_manager.add(name, module, self._collection)
|
||||
|
||||
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
|
||||
def __init__(
|
||||
self,
|
||||
specs: List[Union[ComponentSpec, ConfigSpec]],
|
||||
pretrained_model_name_or_path: Optional[str] = None,
|
||||
component_manager: Optional[ComponentsManager] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the loader with a list of component specs and config specs.
|
||||
"""
|
||||
self._component_manager = component_manager
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)}
|
||||
@@ -1848,6 +1849,10 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
return module.dtype
|
||||
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def component_names(self) -> List[str]:
|
||||
return list(self.components.keys())
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
@@ -1958,12 +1963,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
def load(self, component_names: Optional[List[str]] = None, **kwargs):
|
||||
def load(self, names: Optional[List[str]] = None, **kwargs):
|
||||
"""
|
||||
Load selectedcomponents from specs.
|
||||
Load selected components from specs.
|
||||
|
||||
Args:
|
||||
component_names: List of component names to load
|
||||
names: List of component names to load
|
||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
@@ -1971,19 +1976,19 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
`variant`, `revision`, etc.
|
||||
"""
|
||||
# if not specific name, load all the components with default_creation_method == "from_pretrained"
|
||||
if component_names is None:
|
||||
component_names = [
|
||||
if names is None:
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
elif not isinstance(component_names, list):
|
||||
component_names = [component_names]
|
||||
elif not isinstance(names, list):
|
||||
names = [names]
|
||||
|
||||
components_to_load = {name for name in component_names if name in self._component_specs}
|
||||
unknown_component_names = {name for name in component_names if name not in self._component_specs}
|
||||
if len(unknown_component_names) > 0:
|
||||
logger.warning(f"Unknown components will be ignored: {unknown_component_names}")
|
||||
components_to_load = {name for name in names if name in self._component_specs}
|
||||
unknown_names = {name for name in names if name not in self._component_specs}
|
||||
if len(unknown_names) > 0:
|
||||
logger.warning(f"Unknown components will be ignored: {unknown_names}")
|
||||
|
||||
components_to_register = {}
|
||||
for name in components_to_load:
|
||||
@@ -2240,7 +2245,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
spec_only: bool = True,
|
||||
component_manager: Optional[ComponentsManager] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -2261,7 +2266,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
|
||||
elif name in expected_config:
|
||||
config_specs.append(ConfigSpec(name=name, default=value))
|
||||
|
||||
return cls(component_specs + config_specs, component_manager=component_manager, collection=collection)
|
||||
return cls(component_specs + config_specs, components_manager=components_manager, collection=collection)
|
||||
|
||||
@staticmethod
|
||||
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
|
||||
@@ -2370,20 +2375,20 @@ class ModularPipeline:
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
|
||||
intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediates_inputs:
|
||||
if name not in intermediate_inputs:
|
||||
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.add_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.add_input(name, default, kwargs_type)
|
||||
|
||||
for expected_intermediate_param in self.blocks.intermediates_inputs:
|
||||
for expected_intermediate_param in self.blocks.intermediate_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
@@ -2412,8 +2417,8 @@ class ModularPipeline:
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
|
||||
self.loader.load(component_names=component_names, **kwargs)
|
||||
def load_components(self, names: Optional[List[str]] = None, **kwargs):
|
||||
self.loader.load(names=names, **kwargs)
|
||||
|
||||
def update_components(self, **kwargs):
|
||||
self.loader.update(**kwargs)
|
||||
@@ -2424,7 +2429,7 @@ class ModularPipeline:
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
component_manager: Optional[ComponentsManager] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -2432,7 +2437,7 @@ class ModularPipeline:
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
pipeline = blocks.init_pipeline(
|
||||
pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs
|
||||
pretrained_model_name_or_path, components_manager=components_manager, collection=collection, **kwargs
|
||||
)
|
||||
return pipeline
|
||||
|
||||
|
||||
@@ -49,7 +49,13 @@ class InsertableOrderedDict(OrderedDict):
|
||||
|
||||
items = []
|
||||
for i, (key, value) in enumerate(self.items()):
|
||||
items.append(f"{i}: ({repr(key)}, {repr(value)})")
|
||||
if isinstance(value, type):
|
||||
# For classes, show class name and <class ...>
|
||||
obj_repr = f"<class '{value.__module__}.{value.__name__}'>"
|
||||
else:
|
||||
# For objects (instances) and other types, show class name and module
|
||||
obj_repr = f"<obj '{value.__class__.__module__}.{value.__class__.__name__}'>"
|
||||
items.append(f"{i}: ({repr(key)}, {obj_repr})")
|
||||
|
||||
return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])"
|
||||
|
||||
@@ -260,11 +266,11 @@ class ConfigSpec:
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects
|
||||
# however some fields are not relevant for intermediates_inputs
|
||||
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
|
||||
# however some fields are not relevant for intermediate_inputs
|
||||
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
|
||||
# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs
|
||||
# -> should we use different class for inputs and intermediates_inputs?
|
||||
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
|
||||
# -> should we use different class for inputs and intermediate_inputs?
|
||||
@dataclass
|
||||
class InputParam:
|
||||
"""Specification for an input parameter."""
|
||||
@@ -324,14 +330,14 @@ def format_inputs_short(inputs):
|
||||
return inputs_str
|
||||
|
||||
|
||||
def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs):
|
||||
def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
|
||||
"""
|
||||
Formats intermediate inputs and outputs of a block into a string representation.
|
||||
|
||||
Args:
|
||||
intermediates_inputs: List of intermediate input parameters
|
||||
required_intermediates_inputs: List of required intermediate input names
|
||||
intermediates_outputs: List of intermediate output parameters
|
||||
intermediate_inputs: List of intermediate input parameters
|
||||
required_intermediate_inputs: List of required intermediate input names
|
||||
intermediate_outputs: List of intermediate output parameters
|
||||
|
||||
Returns:
|
||||
str: Formatted string like:
|
||||
@@ -342,8 +348,8 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
|
||||
"""
|
||||
# Handle inputs
|
||||
input_parts = []
|
||||
for inp in intermediates_inputs:
|
||||
if inp.name in required_intermediates_inputs:
|
||||
for inp in intermediate_inputs:
|
||||
if inp.name in required_intermediate_inputs:
|
||||
input_parts.append(f"Required({inp.name})")
|
||||
else:
|
||||
if inp.name is None and inp.kwargs_type is not None:
|
||||
@@ -353,11 +359,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
|
||||
input_parts.append(inp_name)
|
||||
|
||||
# Handle modified variables (appear in both inputs and outputs)
|
||||
inputs_set = {inp.name for inp in intermediates_inputs}
|
||||
inputs_set = {inp.name for inp in intermediate_inputs}
|
||||
modified_parts = []
|
||||
new_output_parts = []
|
||||
|
||||
for out in intermediates_outputs:
|
||||
for out in intermediate_outputs:
|
||||
if out.name in inputs_set:
|
||||
modified_parts.append(out.name)
|
||||
else:
|
||||
@@ -575,7 +581,7 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediates_inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -587,7 +593,7 @@ def make_doc_string(
|
||||
|
||||
Args:
|
||||
inputs: List of input parameters
|
||||
intermediates_inputs: List of intermediate input parameters
|
||||
intermediate_inputs: List of intermediate input parameters
|
||||
outputs: List of output parameters
|
||||
description (str, *optional*): Description of the block
|
||||
class_name (str, *optional*): Name of the class to include in the documentation
|
||||
@@ -621,7 +627,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
@@ -382,7 +382,7 @@ class ModularNode(ConfigMixin):
|
||||
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
|
||||
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
|
||||
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
|
||||
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
|
||||
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
|
||||
for inp in inputs:
|
||||
param = kwargs.pop(inp.name, None)
|
||||
if param:
|
||||
@@ -455,9 +455,9 @@ class ModularNode(ConfigMixin):
|
||||
output_params = {}
|
||||
if isinstance(self.blocks, SequentialPipelineBlocks):
|
||||
last_block_name = list(self.blocks.blocks.keys())[-1]
|
||||
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
|
||||
outputs = self.blocks.blocks[last_block_name].intermediate_outputs
|
||||
else:
|
||||
outputs = self.blocks.intermediates_outputs
|
||||
outputs = self.blocks.intermediate_outputs
|
||||
|
||||
for out in outputs:
|
||||
param = kwargs.pop(out.name, None)
|
||||
@@ -495,9 +495,9 @@ class ModularNode(ConfigMixin):
|
||||
}
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
def setup(self, components, collection=None):
|
||||
self.blocks.setup_loader(component_manager=components, collection=collection)
|
||||
self._components_manager = components
|
||||
def setup(self, components_manager, collection=None):
|
||||
self.blocks.setup_loader(components_manager=components_manager, collection=collection)
|
||||
self._components_manager = components_manager
|
||||
|
||||
@property
|
||||
def mellon_config(self):
|
||||
|
||||
@@ -28,12 +28,13 @@ else:
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"IP_ADAPTER_BLOCKS",
|
||||
"SDXL_SUPPORTED_BLOCKS",
|
||||
"ALL_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLAutoDecodeStep",
|
||||
"StableDiffusionXLAutoIPAdapterStep",
|
||||
"StableDiffusionXLAutoVaeEncoderStep",
|
||||
"StableDiffusionXLAutoControlnetStep",
|
||||
]
|
||||
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
|
||||
|
||||
@@ -53,12 +54,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
IP_ADAPTER_BLOCKS,
|
||||
SDXL_SUPPORTED_BLOCKS,
|
||||
ALL_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLAutoDecodeStep,
|
||||
StableDiffusionXLAutoIPAdapterStep,
|
||||
StableDiffusionXLAutoVaeEncoderStep,
|
||||
StableDiffusionXLAutoControlnetStep,
|
||||
)
|
||||
from .modular_loader import StableDiffusionXLModularLoader
|
||||
else:
|
||||
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
@@ -251,7 +251,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
@@ -423,7 +423,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -434,7 +434,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
@@ -565,7 +565,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
@@ -642,7 +642,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
@@ -678,7 +678,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
@@ -928,7 +928,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
@@ -953,7 +953,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
@@ -1009,7 +1009,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
@@ -1022,7 +1022,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
@@ -1124,7 +1124,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -1147,7 +1147,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"add_time_ids",
|
||||
@@ -1328,7 +1328,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -1351,7 +1351,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"add_time_ids",
|
||||
@@ -1510,7 +1510,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -1538,7 +1538,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"),
|
||||
OutputParam(
|
||||
@@ -1730,7 +1730,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -1764,7 +1764,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"),
|
||||
OutputParam(
|
||||
|
||||
@@ -59,7 +59,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -70,7 +70,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
@@ -170,30 +170,28 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("image"),
|
||||
InputParam("mask_image"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"images",
|
||||
required=True,
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
description="The generated images from the decode step",
|
||||
),
|
||||
InputParam(
|
||||
"crops_coords",
|
||||
required=True,
|
||||
type_hint=Tuple[int, int],
|
||||
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -174,7 +174,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
@@ -280,7 +280,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
@@ -473,13 +473,13 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||
|
||||
# YiYi TODO: move this out of here
|
||||
@@ -545,7 +545,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[str]:
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
@@ -572,7 +572,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||
|
||||
@staticmethod
|
||||
@@ -660,7 +660,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediates_inputs(self) -> List[InputParam]:
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -63,8 +63,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
|
||||
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
|
||||
"IP Adapter step that prepares ip adapter image embeddings.\n"
|
||||
"Note that this step only prepares the embeddings - in order for it to work correctly, "
|
||||
"you need to load ip adapter weights into unet via ModularPipeline.loader.\n"
|
||||
"e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().\n"
|
||||
"See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
|
||||
" for more details"
|
||||
)
|
||||
|
||||
@@ -99,7 +102,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
|
||||
OutputParam(
|
||||
@@ -251,7 +254,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
@@ -602,7 +605,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
@@ -614,7 +617,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
@@ -727,14 +730,14 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_inputs(self) -> List[InputParam]:
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[OutputParam]:
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
|
||||
@@ -844,6 +847,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
|
||||
if block_state.height is None:
|
||||
block_state.height = components.default_height
|
||||
if block_state.width is None:
|
||||
block_state.width = components.default_width
|
||||
|
||||
if block_state.padding_mask_crop is not None:
|
||||
block_state.crops_coords = components.mask_processor.get_crop_region(
|
||||
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
|
||||
|
||||
@@ -68,7 +68,7 @@ class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# optional ip-adapter (run before before_denoise)
|
||||
# optional ip-adapter (run before input step)
|
||||
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
|
||||
block_classes = [StableDiffusionXLIPAdapterStep]
|
||||
block_names = ["ip_adapter"]
|
||||
@@ -76,7 +76,9 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Run IP Adapter step if `ip_adapter_image` is provided."
|
||||
return (
|
||||
"Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img
|
||||
@@ -370,7 +372,7 @@ AUTO_BLOCKS = InsertableOrderedDict(
|
||||
)
|
||||
|
||||
|
||||
SDXL_SUPPORTED_BLOCKS = {
|
||||
ALL_BLOCKS = {
|
||||
"text2img": TEXT2IMAGE_BLOCKS,
|
||||
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||
"inpaint": INPAINT_BLOCKS,
|
||||
|
||||
@@ -44,6 +44,16 @@ class StableDiffusionXLModularLoader(
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
ModularIPAdapterMixin,
|
||||
):
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
default_sample_size = 128
|
||||
|
||||
Reference in New Issue
Block a user