mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

intermediates_inputs -> intermediate_inputs; component_manager -> components_manager, and more

yiyixuxu
2025-06-27 12:48:30 +02:00
parent 7608d2eb9e
commit f63d62e091
11 changed files with 222 additions and 232 deletions
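For downstream code this commit is a breaking rename of several public names. A minimal before/after sketch (`blocks`, `repo_id`, and `manager` are placeholders):

    # before this commit
    pipe = blocks.init_pipeline(repo_id, component_manager=manager)
    pipe.loader.load(component_names=["unet"])
    names = [inp.name for inp in blocks.intermediates_inputs]

    # after this commit
    pipe = blocks.init_pipeline(repo_id, components_manager=manager)
    pipe.loader.load(names=["unet"])
    names = [inp.name for inp in blocks.intermediate_inputs]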


@@ -284,12 +284,12 @@ class ComponentsManager:
if comp == component:
comp_name = self._id_to_name(comp_id)
if comp_name == name:
logger.warning(f"component '{name}' already exists as '{comp_id}'")
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
component_id = comp_id
break
else:
logger.warning(
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
@@ -301,7 +301,7 @@ class ComponentsManager:
if components_with_same_load_id:
existing = ", ".join(components_with_same_load_id)
logger.warning(
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
@@ -315,12 +315,12 @@ class ComponentsManager:
if component_id not in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}")
self.remove(comp_id)
self.collections[collection].add(component_id)
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}")
else:
logger.info(f"Added component '{name}' as '{component_id}'")
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
@@ -659,6 +659,10 @@ class ComponentsManager:
return info
def __repr__(self):
# Handle empty components case
if not self.components:
return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50
# Helper to get simple name without UUID
def get_simple_name(name):
# Extract the base name by splitting on underscore and taking first part
@@ -802,51 +806,6 @@ class ComponentsManager:
return output
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
"""
Load components from a pretrained model and add them to the manager.
Args:
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
prefix (str, optional): Prefix to add to all component names loaded from this model.
If provided, components will be named as "{prefix}_{component_name}"
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
subfolder = kwargs.pop("subfolder", None)
# YiYi TODO: extend AutoModel to support non-diffusers models
if subfolder:
from ..models import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
else:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def get_one(
self,
component_id: Optional[str] = None,


@@ -126,7 +126,7 @@ class PipelineState:
input_names = self.input_kwargs.get(kwargs_type, [])
return self.get_inputs(input_names)
def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
"""
Get all intermediates with matching kwargs_type.
@@ -325,7 +325,7 @@ class ModularPipelineBlocks(ConfigMixin):
def init_pipeline(
self,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
component_manager: Optional[ComponentsManager] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
):
"""
@@ -344,10 +344,10 @@ class ModularPipelineBlocks(ConfigMixin):
loader = loader_class(
specs=specs,
pretrained_model_name_or_path=pretrained_model_name_or_path,
component_manager=component_manager,
components_manager=components_manager,
collection=collection,
)
modular_pipeline = ModularPipeline(blocks=self, loader=loader)
modular_pipeline = ModularPipeline(blocks=deepcopy(self), loader=loader)
return modular_pipeline
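Under the new name, wiring a pipeline to a manager looks like this sketch (repo id illustrative). Note the new `deepcopy(self)`: the returned pipeline now owns its own copy of the blocks, so later mutation of the original blocks no longer leaks into it.

    manager = ComponentsManager()
    pipe = blocks.init_pipeline(
        "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative repo id
        components_manager=manager,
        collection="sdxl",
    )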
@@ -374,17 +374,17 @@ class PipelineBlock(ModularPipelineBlocks):
return []
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
"""List of intermediate input parameters. Must be implemented by subclasses."""
return []
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
def _get_outputs(self):
return self.intermediates_outputs
return self.intermediate_outputs
# YiYi TODO: is it too easy for user to unintentionally override these properties?
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
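A hedged sketch of a custom block under the renamed properties (the `get_block_state` helper name is an assumption; this diff only shows `add_block_state`):

    class ScaleLatentsBlock(PipelineBlock):
        @property
        def intermediate_inputs(self) -> List[InputParam]:
            return [InputParam("latents", required=True, type_hint=torch.Tensor,
                               description="latents to rescale")]

        @property
        def intermediate_outputs(self) -> List[OutputParam]:
            return [OutputParam("latents", type_hint=torch.Tensor,
                                description="rescaled latents")]

        def __call__(self, pipeline, state: PipelineState) -> PipelineState:
            block_state = self.get_block_state(state)  # assumed helper, see note above
            block_state.latents = block_state.latents * 0.5  # illustrative transform
            self.add_block_state(state, block_state)
            return state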
@@ -403,9 +403,9 @@ class PipelineBlock(ModularPipelineBlocks):
def required_inputs(self) -> List[str]:
return self._get_required_inputs()
def _get_required_intermediates_inputs(self):
def _get_required_intermediate_inputs(self):
input_names = []
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@@ -413,8 +413,8 @@ class PipelineBlock(ModularPipelineBlocks):
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
return self._get_required_intermediates_inputs()
def required_intermediate_inputs(self) -> List[str]:
return self._get_required_intermediate_inputs()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
raise NotImplementedError("__call__ method must be implemented in subclasses")
@@ -449,7 +449,7 @@ class PipelineBlock(ModularPipelineBlocks):
# Intermediates section
intermediates_str = format_intermediates_short(
self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
)
intermediates = f"Intermediates:\n{intermediates_str}"
@@ -459,7 +459,7 @@ class PipelineBlock(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -492,7 +492,7 @@ class PipelineBlock(ModularPipelineBlocks):
data[input_param.kwargs_type][k] = v
# Check intermediates
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
@@ -503,9 +503,9 @@ class PipelineBlock(ModularPipelineBlocks):
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
if intermediates_kwargs:
for k, v in intermediates_kwargs.items():
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
@@ -513,13 +513,13 @@ class PipelineBlock(ModularPipelineBlocks):
return BlockState(**data)
def add_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediates_outputs:
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
@@ -527,7 +527,7 @@ class PipelineBlock(ModularPipelineBlocks):
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
@@ -537,8 +537,8 @@ class PipelineBlock(ModularPipelineBlocks):
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediates_kwargs.items():
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.add_intermediate(param_name, param, input_param.kwargs_type)
@@ -610,6 +610,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
return list(combined_dict.values())
# YiYi TODO: change blocks attribute to a different name, so it is not confused with the blocks attribute in ModularPipeline
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
A class that automatically selects a block to run based on the inputs.
@@ -692,15 +693,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
def required_intermediate_inputs(self) -> List[str]:
if None not in self.block_trigger_inputs:
return []
first_block = next(iter(self.blocks.values()))
required_by_all = set(getattr(first_block, "required_intermediates_inputs", set()))
required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
# Intersect with required inputs from all other blocks
for block in list(self.blocks.values())[1:]:
block_required = set(getattr(block, "required_intermediates_inputs", set()))
block_required = set(getattr(block, "required_intermediate_inputs", set()))
required_by_all.intersection_update(block_required)
return list(required_by_all)
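The trigger mechanics under the new names, sketched after the pattern of `StableDiffusionXLAutoIPAdapterStep` later in this commit (the block classes are hypothetical):

    class MyAutoStep(AutoPipelineBlocks):
        block_classes = [MyIPAdapterStep, MyDefaultStep]   # hypothetical blocks
        block_names = ["ip_adapter", "default"]
        # run 'ip_adapter' when `ip_adapter_image` is passed; None marks the fallback block
        block_trigger_inputs = ["ip_adapter_image", None]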
@@ -719,20 +720,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return combined_inputs
@property
def intermediates_inputs(self) -> List[str]:
named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
def intermediate_inputs(self) -> List[str]:
named_inputs = [(name, block.intermediate_inputs) for name, block in self.blocks.items()]
combined_inputs = combine_inputs(*named_inputs)
# mark Required inputs only if that input is required by all the blocks
for input_param in combined_inputs:
if input_param.name in self.required_intermediates_inputs:
if input_param.name in self.required_intermediate_inputs:
input_param.required = True
else:
input_param.required = False
return combined_inputs
@property
def intermediates_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
@@ -885,7 +886,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -975,12 +976,12 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
required_intermediates_inputs = []
for input_param in self.intermediates_inputs:
def required_intermediate_inputs(self) -> List[str]:
required_intermediate_inputs = []
for input_param in self.intermediate_inputs:
if input_param.required:
required_intermediates_inputs.append(input_param.name)
return required_intermediates_inputs
required_intermediate_inputs.append(input_param.name)
return required_intermediate_inputs
# YiYi TODO: add test for this
@property
@@ -999,10 +1000,10 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return combined_inputs
@property
def intermediates_inputs(self) -> List[str]:
return self.get_intermediates_inputs()
def intermediate_inputs(self) -> List[str]:
return self.get_intermediate_inputs()
def get_intermediates_inputs(self):
def get_intermediate_inputs(self):
inputs = []
outputs = set()
added_inputs = set()
@@ -1010,7 +1011,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# Go through all blocks in order
for block in self.blocks.values():
# Add inputs that aren't in outputs yet
for inp in block.intermediates_inputs:
for inp in block.intermediate_inputs:
if inp.name not in outputs and inp.name not in added_inputs:
inputs.append(inp)
added_inputs.add(inp.name)
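The scan above is easy to state standalone: an intermediate input only surfaces at the pipeline level if no earlier block has already produced it. A toy illustration (plain Python, not diffusers code; it omits the skippable-block check below):

    def resolve_intermediate_inputs(blocks):
        # blocks: ordered list of (input_names, output_names) pairs
        produced, needed = set(), []
        for input_names, output_names in blocks:
            for name in input_names:
                if name not in produced and name not in needed:
                    needed.append(name)
            produced.update(output_names)
        return needed

    # "prompt_embeds" is produced by the first block, so only
    # "image_latents" and "timesteps" bubble up:
    resolve_intermediate_inputs([
        (["image_latents"], ["prompt_embeds"]),
        (["prompt_embeds", "timesteps"], ["latents"]),
    ])  # -> ["image_latents", "timesteps"]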
@@ -1022,27 +1023,27 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
if should_add_outputs:
# Add this block's outputs
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
outputs.update(block_intermediates_outputs)
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
return inputs
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
named_outputs = []
for name, block in self.blocks.items():
inp_names = {inp.name for inp in block.intermediates_inputs}
# so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
# filter out them here so they do not end up as intermediates_outputs
inp_names = {inp.name for inp in block.intermediate_inputs}
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
# filter out them here so they do not end up as intermediate_outputs
if name not in inp_names:
named_outputs.append((name, block.intermediates_outputs))
named_outputs.append((name, block.intermediate_outputs))
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# YiYi TODO: I think we can remove the outputs property
@property
def outputs(self) -> List[str]:
# return next(reversed(self.blocks.values())).intermediates_outputs
return self.intermediates_outputs
# return next(reversed(self.blocks.values())).intermediate_outputs
return self.intermediate_outputs
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
@@ -1248,7 +1249,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -1287,12 +1288,12 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return []
@property
def loop_intermediates_inputs(self) -> List[InputParam]:
def loop_intermediate_inputs(self) -> List[InputParam]:
"""List of intermediate input parameters. Must be implemented by subclasses."""
return []
@property
def loop_intermediates_outputs(self) -> List[OutputParam]:
def loop_intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
@@ -1305,9 +1306,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return input_names
@property
def loop_required_intermediates_inputs(self) -> List[str]:
def loop_required_intermediate_inputs(self) -> List[str]:
input_names = []
for input_param in self.loop_intermediates_inputs:
for input_param in self.loop_intermediate_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@@ -1356,25 +1357,25 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def inputs(self):
return self.get_inputs()
# modified from SequentialPipelineBlocks to include loop_intermediates_inputs
# modified from SequentialPipelineBlocks to include loop_intermediate_inputs
@property
def intermediates_inputs(self):
intermediates = self.get_intermediates_inputs()
def intermediate_inputs(self):
intermediates = self.get_intermediate_inputs()
intermediate_names = [input.name for input in intermediates]
for loop_intermediate_input in self.loop_intermediates_inputs:
for loop_intermediate_input in self.loop_intermediate_inputs:
if loop_intermediate_input.name not in intermediate_names:
intermediates.append(loop_intermediate_input)
return intermediates
# modified from SequentialPipelineBlocks
def get_intermediates_inputs(self):
def get_intermediate_inputs(self):
inputs = []
outputs = set()
# Go through all blocks in order
for block in self.blocks.values():
# Add inputs that aren't in outputs yet
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
@@ -1383,8 +1384,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
if should_add_outputs:
# Add this block's outputs
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
outputs.update(block_intermediates_outputs)
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
return inputs
# modified from SequentialPipelineBlocks, if any additional input required by the loop is required by the block
@@ -1407,23 +1408,23 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediates_inputs(self) -> List[str]:
required_intermediates_inputs = []
for input_param in self.intermediates_inputs:
def required_intermediate_inputs(self) -> List[str]:
required_intermediate_inputs = []
for input_param in self.intermediate_inputs:
if input_param.required:
required_intermediates_inputs.append(input_param.name)
for input_param in self.loop_intermediates_inputs:
required_intermediate_inputs.append(input_param.name)
for input_param in self.loop_intermediate_inputs:
if input_param.required:
required_intermediates_inputs.append(input_param.name)
return required_intermediates_inputs
required_intermediate_inputs.append(input_param.name)
return required_intermediate_inputs
# YiYi TODO: this need to be thought about more
# modified from SequentialPipelineBlocks to include loop_intermediates_outputs
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs
@property
def intermediates_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
for output in self.loop_intermediates_outputs:
for output in self.loop_intermediate_outputs:
if output.name not in {output.name for output in combined_outputs}:
combined_outputs.append(output)
return combined_outputs
@@ -1431,7 +1432,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO: this need to be thought about more
@property
def outputs(self) -> List[str]:
return next(reversed(self.blocks.values())).intermediates_outputs
return next(reversed(self.blocks.values())).intermediate_outputs
def __init__(self):
blocks = InsertableOrderedDict()
@@ -1497,7 +1498,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
data[input_param.kwargs_type][k] = v
# Check intermediates
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
@@ -1508,9 +1509,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
if intermediates_kwargs:
for k, v in intermediates_kwargs.items():
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
if intermediate_kwargs:
for k, v in intermediate_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
@@ -1518,13 +1519,13 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return BlockState(**data)
def add_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediates_outputs:
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
for input_param in self.intermediates_inputs:
for input_param in self.intermediate_inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
@@ -1534,8 +1535,8 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediates_kwargs.items():
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
@@ -1546,7 +1547,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
self.intermediates_inputs,
self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -1660,7 +1661,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
- non from_pretrained components are created during __init__ and registered as the object itself
- Components are updated with the `update()` method: e.g. loader.update(unet=unet) or
loader.update(guider=guider_spec)
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"])
- (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"])
Args:
**kwargs: Keyword arguments where keys are component names and values are component objects.
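In sketch form (the component objects and spec are placeholders):

    loader.update(unet=unet)           # swap in an already-instantiated component
    loader.update(guider=guider_spec)  # or update from a ComponentSpec
    loader.load(names=["unet"])        # load named from_pretrained components
    loader.load()                      # no names: everything marked from_pretrained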
@@ -1710,8 +1711,8 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
if not is_registered:
self.register_to_config(**register_dict)
setattr(self, name, module)
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
self._component_manager.add(name, module, self._collection)
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
self._components_manager.add(name, module, self._collection)
continue
current_module = getattr(self, name, None)
@@ -1745,22 +1746,22 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
# finally set models
setattr(self, name, module)
# add to component manager if one is attached
if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None:
self._component_manager.add(name, module, self._collection)
if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None:
self._components_manager.add(name, module, self._collection)
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
def __init__(
self,
specs: List[Union[ComponentSpec, ConfigSpec]],
pretrained_model_name_or_path: Optional[str] = None,
component_manager: Optional[ComponentsManager] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
**kwargs,
):
"""
Initialize the loader with a list of component specs and config specs.
"""
self._component_manager = component_manager
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)}
self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)}
@@ -1848,6 +1849,10 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
return module.dtype
return torch.float32
@property
def component_names(self) -> List[str]:
return list(self.components.keys())
@property
def components(self) -> Dict[str, Any]:
@@ -1958,12 +1963,12 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
self.register_to_config(**config_to_register)
# YiYi TODO: support map for additional from_pretrained kwargs
def load(self, component_names: Optional[List[str]] = None, **kwargs):
def load(self, names: Optional[List[str]] = None, **kwargs):
"""
Load selectedcomponents from specs.
Load selected components from specs.
Args:
component_names: List of component names to load
names: List of component names to load
**kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
@@ -1971,19 +1976,19 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
`variant`, `revision`, etc.
"""
# if no names are specified, load all the components with default_creation_method == "from_pretrained"
if component_names is None:
component_names = [
if names is None:
names = [
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
]
elif not isinstance(component_names, list):
component_names = [component_names]
elif not isinstance(names, list):
names = [names]
components_to_load = {name for name in component_names if name in self._component_specs}
unknown_component_names = {name for name in component_names if name not in self._component_specs}
if len(unknown_component_names) > 0:
logger.warning(f"Unknown components will be ignored: {unknown_component_names}")
components_to_load = {name for name in names if name in self._component_specs}
unknown_names = {name for name in names if name not in self._component_specs}
if len(unknown_names) > 0:
logger.warning(f"Unknown components will be ignored: {unknown_names}")
components_to_register = {}
for name in components_to_load:
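The dict form documented above routes per-component values, e.g. (dtype choices illustrative):

    loader.load(
        names=["unet", "vae"],
        torch_dtype={"unet": torch.bfloat16, "default": torch.float32},
    )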
@@ -2240,7 +2245,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
spec_only: bool = True,
component_manager: Optional[ComponentsManager] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
**kwargs,
):
@@ -2261,7 +2266,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
elif name in expected_config:
config_specs.append(ConfigSpec(name=name, default=value))
return cls(component_specs + config_specs, component_manager=component_manager, collection=collection)
return cls(component_specs + config_specs, components_manager=components_manager, collection=collection)
@staticmethod
def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
@@ -2370,20 +2375,20 @@ class ModularPipeline:
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
intermediates_inputs = [inp.name for inp in self.blocks.intermediates_inputs]
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
if name in passed_kwargs:
if name not in intermediates_inputs:
if name not in intermediate_inputs:
state.add_input(name, passed_kwargs.pop(name), kwargs_type)
else:
state.add_input(name, passed_kwargs[name], kwargs_type)
elif name not in state.inputs:
state.add_input(name, default, kwargs_type)
for expected_intermediate_param in self.blocks.intermediates_inputs:
for expected_intermediate_param in self.blocks.intermediate_inputs:
name = expected_intermediate_param.name
kwargs_type = expected_intermediate_param.kwargs_type
if name in passed_kwargs:
@@ -2412,8 +2417,8 @@ class ModularPipeline:
else:
raise ValueError(f"Output '{output}' is not a valid output type")
def load_components(self, component_names: Optional[List[str]] = None, **kwargs):
self.loader.load(component_names=component_names, **kwargs)
def load_components(self, names: Optional[List[str]] = None, **kwargs):
self.loader.load(names=names, **kwargs)
def update_components(self, **kwargs):
self.loader.update(**kwargs)
@@ -2424,7 +2429,7 @@ class ModularPipeline:
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
trust_remote_code: Optional[bool] = None,
component_manager: Optional[ComponentsManager] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
**kwargs,
):
@@ -2432,7 +2437,7 @@ class ModularPipeline:
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
pipeline = blocks.init_pipeline(
pretrained_model_name_or_path, component_manager=component_manager, collection=collection, **kwargs
pretrained_model_name_or_path, components_manager=components_manager, collection=collection, **kwargs
)
return pipeline


@@ -49,7 +49,13 @@ class InsertableOrderedDict(OrderedDict):
items = []
for i, (key, value) in enumerate(self.items()):
items.append(f"{i}: ({repr(key)}, {repr(value)})")
if isinstance(value, type):
# For classes, show class name and <class ...>
obj_repr = f"<class '{value.__module__}.{value.__name__}'>"
else:
# For objects (instances) and other types, show class name and module
obj_repr = f"<obj '{value.__class__.__module__}.{value.__class__.__name__}'>"
items.append(f"{i}: ({repr(key)}, {obj_repr})")
return "InsertableOrderedDict([\n " + ",\n ".join(items) + "\n])"
@@ -260,11 +266,11 @@ class ConfigSpec:
description: Optional[str] = None
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects
# however some fields are not relevant for intermediates_inputs
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs
# -> should we use different class for inputs and intermediates_inputs?
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
@@ -324,14 +330,14 @@ def format_inputs_short(inputs):
return inputs_str
def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs):
def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
"""
Formats intermediate inputs and outputs of a block into a string representation.
Args:
intermediates_inputs: List of intermediate input parameters
required_intermediates_inputs: List of required intermediate input names
intermediates_outputs: List of intermediate output parameters
intermediate_inputs: List of intermediate input parameters
required_intermediate_inputs: List of required intermediate input names
intermediate_outputs: List of intermediate output parameters
Returns:
str: Formatted string like:
@@ -342,8 +348,8 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
"""
# Handle inputs
input_parts = []
for inp in intermediates_inputs:
if inp.name in required_intermediates_inputs:
for inp in intermediate_inputs:
if inp.name in required_intermediate_inputs:
input_parts.append(f"Required({inp.name})")
else:
if inp.name is None and inp.kwargs_type is not None:
@@ -353,11 +359,11 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
inputs_set = {inp.name for inp in intermediate_inputs}
modified_parts = []
new_output_parts = []
for out in intermediates_outputs:
for out in intermediate_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
@@ -575,7 +581,7 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string(
inputs,
intermediates_inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
@@ -587,7 +593,7 @@ def make_doc_string(
Args:
inputs: List of input parameters
intermediates_inputs: List of intermediate input parameters
intermediate_inputs: List of intermediate input parameters
outputs: List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
@@ -621,7 +627,7 @@ def make_doc_string(
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
# Add outputs section
output += "\n\n"


@@ -382,7 +382,7 @@ class ModularNode(ConfigMixin):
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
@@ -455,9 +455,9 @@ class ModularNode(ConfigMixin):
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
outputs = self.blocks.blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediates_outputs
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
@@ -495,9 +495,9 @@ class ModularNode(ConfigMixin):
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
def setup(self, components_manager, collection=None):
self.blocks.setup_loader(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):


@@ -28,12 +28,13 @@ else:
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"IP_ADAPTER_BLOCKS",
"SDXL_SUPPORTED_BLOCKS",
"ALL_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLAutoDecodeStep",
"StableDiffusionXLAutoIPAdapterStep",
"StableDiffusionXLAutoVaeEncoderStep",
"StableDiffusionXLAutoControlnetStep",
]
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
@@ -53,12 +54,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
SDXL_SUPPORTED_BLOCKS,
ALL_BLOCKS,
TEXT2IMAGE_BLOCKS,
StableDiffusionXLAutoBlocks,
StableDiffusionXLAutoDecodeStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLAutoControlnetStep,
)
from .modular_loader import StableDiffusionXLModularLoader
else:


@@ -215,7 +215,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
@@ -251,7 +251,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
@@ -423,7 +423,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"batch_size",
@@ -434,7 +434,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
@@ -565,7 +565,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
@@ -642,7 +642,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
@@ -678,7 +678,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
@@ -928,7 +928,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
@@ -953,7 +953,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
@@ -1009,7 +1009,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
@@ -1022,7 +1022,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
@@ -1124,7 +1124,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
@@ -1147,7 +1147,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"add_time_ids",
@@ -1328,7 +1328,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
@@ -1351,7 +1351,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"add_time_ids",
@@ -1510,7 +1510,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -1538,7 +1538,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"),
OutputParam(
@@ -1730,7 +1730,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
@@ -1764,7 +1764,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"),
OutputParam(


@@ -59,7 +59,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -70,7 +70,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
@@ -170,30 +170,28 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"images",
required=True,
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images from the decode step",
),
InputParam(
"crops_coords",
required=True,
type_hint=Tuple[int, int],
description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
),
]
@property
def intermediates_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",


@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -174,7 +174,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"num_inference_steps",
@@ -280,7 +280,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
@@ -473,13 +473,13 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
# YiYi TODO: move this out of here
@@ -545,7 +545,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[str]:
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
@@ -572,7 +572,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@staticmethod
@@ -660,7 +660,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
def loop_intermediates_inputs(self) -> List[InputParam]:
def loop_intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",


@@ -63,8 +63,11 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
@property
def description(self) -> str:
return (
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
"IP Adapter step that prepares ip adapter image embeddings.\n"
"Note that this step only prepares the embeddings - in order for it to work correctly, "
"you need to load ip adapter weights into unet via ModularPipeline.loader.\n"
"e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().\n"
"See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
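Per the new description, loading the weights is now a separate, user-driven step; for example (checkpoint names are the commonly used IP-Adapter SDXL files, shown for illustration):

    pipeline.loader.load_ip_adapter(
        "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin"
    )
    pipeline.loader.set_ip_adapter_scale(0.6)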
@@ -99,7 +102,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam(
@@ -251,7 +254,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
@@ -602,7 +605,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
@@ -614,7 +617,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
@@ -727,14 +730,14 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
]
@property
def intermediates_inputs(self) -> List[InputParam]:
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
@@ -844,6 +847,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.height is None:
block_state.height = components.default_height
if block_state.width is None:
block_state.width = components.default_width
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop


@@ -68,7 +68,7 @@ class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
)
# optional ip-adapter (run before before_denoise)
# optional ip-adapter (run before input step)
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
@@ -76,7 +76,9 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."
return (
"Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
)
# before_denoise: text2img
@@ -370,7 +372,7 @@ AUTO_BLOCKS = InsertableOrderedDict(
)
SDXL_SUPPORTED_BLOCKS = {
ALL_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,


@@ -44,6 +44,16 @@ class StableDiffusionXLModularLoader(
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
default_sample_size = 128
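With `default_sample_size = 128` and SDXL's usual `vae_scale_factor` of 8, both `default_height` and `default_width` work out to 128 * 8 = 1024.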