diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py deleted file mode 100644 index 97a8677bda..0000000000 --- a/src/diffusers/pipelines/modular_pipeline.py +++ /dev/null @@ -1,1916 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type - - -import torch -from tqdm.auto import tqdm -import re -import os -import importlib - -from huggingface_hub.utils import validate_hf_hub_args - -from ..configuration_utils import ConfigMixin, FrozenDict -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, - PushToHubMixin, -) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple -from .modular_pipeline_utils import ( - ComponentSpec, - ConfigSpec, - InputParam, - OutputParam, - format_components, - format_configs, - format_input_params, - format_inputs_short, - format_intermediates_short, - format_output_params, - format_params, - make_doc_string, -) -from .components_manager import ComponentsManager - -from copy import deepcopy -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -MODULAR_LOADER_MAPPING = OrderedDict( - [ - ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), - ] -) - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. - """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - input_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - intermediate_kwargs: Dict[str, list[str, Any]] = field(default_factory=dict) - - def add_input(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an input to the pipeline state with optional metadata. - - Args: - key (str): The key for the input - value (Any): The input value - kwargs_type (str): The kwargs_type to store with the input - """ - self.inputs[key] = value - if kwargs_type is not None: - if kwargs_type not in self.input_kwargs: - self.input_kwargs[kwargs_type] = [key] - else: - self.input_kwargs[kwargs_type].append(key) - - def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): - """ - Add an intermediate value to the pipeline state with optional metadata. 
- - Args: - key (str): The key for the intermediate value - value (Any): The intermediate value - kwargs_type (str): The kwargs_type to store with the intermediate value - """ - self.intermediates[key] = value - if kwargs_type is not None: - if kwargs_type not in self.intermediate_kwargs: - self.intermediate_kwargs[kwargs_type] = [key] - else: - self.intermediate_kwargs[kwargs_type].append(key) - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.inputs.get(key, default) for key in keys} - - def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all inputs with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of inputs with matching kwargs_type - """ - input_names = self.input_kwargs.get(kwargs_type, []) - return self.get_inputs(input_names) - - def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all intermediates with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of intermediates with matching kwargs_type - """ - intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) - return self.get_intermediates(intermediate_names) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.intermediates.get(key, default) for key in keys} - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - - # Format input_kwargs and intermediate_kwargs - input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) - intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" input_kwargs={{\n{input_kwargs_str}\n }},\n" - f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __getitem__(self, key: str): - # allows block_state["foo"] - return getattr(self, key, None) - - def __setitem__(self, key: str, value: Any): - # allows block_state["foo"] = "bar" - setattr(self, key, value) - - def as_dict(self): - """ - Convert BlockState to a dictionary. 
- - Returns: - Dict[str, Any]: Dictionary containing all attributes of the BlockState - """ - return {key: value for key, value in self.__dict__.items()} - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - formatted_dict = {} - for k, val in v.items(): - if hasattr(val, "shape") and hasattr(val, "dtype"): - formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})" - elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"): - shapes = [t.shape for t in val] - formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}" - else: - formatted_dict[k] = repr(val) - return formatted_dict - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - - -class ModularPipelineMixin: - """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks - """ - - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): - """ - create a mouldar loader, optionally accept modular_repo to load from hub. - """ - - # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] - diffusers_module = importlib.import_module("diffusers") - loader_class = getattr(diffusers_module, loader_class_name) - - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - # Create the loader with the updated specs - specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) - - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. 
- """ - if state is None: - state = PipelineState() - - if not hasattr(self, "loader"): - logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.") - self.loader = None - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for expected_input_param in self.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediates_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.intermediates_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - @torch.compiler.disable - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - -class PipelineBlock(ModularPipelineMixin): - - model_name = None - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - - - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. 
Must be implemented by subclasses.""" - return [] - - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs - - @property - def required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - use format_components with add_empty_lines=False - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - components = " " + components_str.replace("\n", "\n ") - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - configs = " " + configs_str.replace("\n", "\n ") - - # Inputs section - inputs_str = format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f"{components}\n" - f"{configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in 
self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name is None and input_param.kwargs_type is not None: - input_name = "*_" + input_param.kwargs_type - else: - input_name = input_param.name - if input_name in combined_dict: - current_param = combined_dict[input_name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. 
- - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {} # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - -class AutoPipelineBlocks(ModularPipelineMixin): - """ - A class that automatically selects a block to run based on the inputs. - - Attributes: - block_classes: List of block classes to be used - block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_names = [] - block_trigger_inputs = [] - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last - # the order of blocksmatters here because the first block with matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): - raise ValueError( - f"In {self.__class__.__name__}, exactly one None must be specified as the last element " - "in block_trigger_inputs." 
- ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - - @property - def required_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - @property - def required_intermediates_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - - @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and 
state.get_input(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - elif input_name is not None and state.get_intermediate(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.warning(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - - return trigger_inputs - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -class SequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. - """ - block_classes = [] - block_names = [] - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": - """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. 
- - Args: - blocks_dict: Dictionary mapping block names to block instances - - Returns: - A new SequentialPipelineBlocks instance - """ - instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict - return instance - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - return self.get_inputs() - - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_inputs(self) -> List[str]: - return self.get_intermediates_inputs() - - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): - try: - pipeline, state = block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return pipeline, state - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. 
- Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - return fn_recursive_get_trigger(self.blocks) - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) - def fn_recursive_traverse(block, block_name, active_triggers): - result_blocks = OrderedDict() - - # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): - # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} - result_blocks.update(blocks_to_update) - else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): - active_triggers.update(out.name for out in block.outputs) - return result_blocks - - # auto - else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - matching_trigger = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - matching_trigger = None - - if this_block is not None: - # sequential/auto (keep traversing) - if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? 
- if hasattr(this_block, 'outputs'): - active_triggers.update(out.name for out in this_block.outputs) - - return result_blocks - - all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - all_blocks.update(blocks_to_update) - return all_blocks - - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs - - if trigger_inputs is not None: - - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) - return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - focus only on expected components - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - - # Blocks section - moved to the end with simplified format - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n\n" - - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - - - @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs - ) - -#YiYi TODO: __repr__ -class LoopSequentialPipelineBlocks(ModularPipelineMixin): - """ - A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence. - """ - - model_name = None - block_classes = [] - block_names = [] - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def loop_expected_configs(self) -> List[ConfigSpec]: - return [] - - @property - def loop_inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. 
Must be implemented by subclasses.""" - return [] - - - @property - def loop_required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def loop_required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - # modified from SequentialPipelineBlocks to include loop_expected_components - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - for component in self.loop_expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - # modified from SequentialPipelineBlocks to include loop_expected_configs - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - for config in self.loop_expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - # modified from SequentialPipelineBlocks to include loop_inputs - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - named_inputs.append(("loop", self.loop_inputs)) - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - # Copied from SequentialPipelineBlocks - @property - def inputs(self): - return self.get_inputs() - - - # modified from SequentialPipelineBlocks to include loop_intermediates_inputs - @property - def intermediates_inputs(self): - intermediates = self.get_intermediates_inputs() - intermediate_names = [input.name for input in intermediates] - for loop_intermediate_input in self.loop_intermediates_inputs: - if loop_intermediate_input.name not in intermediate_names: - intermediates.append(loop_intermediate_input) - return intermediates - - - # Copied from SequentialPipelineBlocks - def get_intermediates_inputs(self): - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - - # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - 
required_by_loop = set(getattr(self, "loop_required_inputs", set())) - required_by_any.update(required_by_loop) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - # modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - for input_param in self.loop_intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - - # YiYi TODO: this need to be thought about more - # modified from SequentialPipelineBlocks to include loop_intermediates_outputs - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - for output in self.loop_intermediates_outputs: - if output.name not in set([output.name for output in combined_outputs]): - combined_outputs.append(output) - return combined_outputs - - # YiYi TODO: this need to be thought about more - # copied from SequentialPipelineBlocks - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - def loop_step(self, components, state: PipelineState, **kwargs): - - for block_name, block in self.blocks.items(): - try: - components, state = block(components, state, **kwargs) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return components, state - - def __call__(self, components, state: PipelineState) -> PipelineState: - raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediates_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif 
value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) - if intermediates_kwargs: - for k, v in intermediates_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) - -# YiYi TODO: -# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) -# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader -# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() -class ModularLoader(ConfigMixin, PushToHubMixin): - """ - Base class for all Modular pipelines loaders. - - """ - config_name = "modular_model_index.json" - - - def register_components(self, **kwargs): - """ - Register components with their corresponding specs. - This method is called when component changed or __init__ is called. - - Args: - **kwargs: Keyword arguments where keys are component names and values are component objects. - - """ - for name, module in kwargs.items(): - - # current component spec - component_spec = self._component_specs.get(name) - if component_spec is None: - logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") - continue - - is_registered = hasattr(self, name) - - if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - - # actual library and class name of the module - - if module is not None: - library, class_name = _fetch_class_library_tuple(module) - new_component_spec = ComponentSpec.from_component(name, module) - component_spec_dict = self._component_spec_to_dict(new_component_spec) - - else: - library, class_name = None, None - # if module is None, we do not update the spec, - # but we still need to update the config to make sure it's synced with the component spec - # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) - new_component_spec = component_spec - component_spec_dict = self._component_spec_to_dict(component_spec) - - # do not register if component is not to be loaded from pretrained - if new_component_spec.default_creation_method == "from_pretrained": - register_dict = {name: (library, class_name, component_spec_dict)} - else: - register_dict = {} - - # set the component as attribute - # if it is not set yet, just set it and skip the process to check and warn below - if not is_registered: - self.register_to_config(**register_dict) - self._component_specs[name] = new_component_spec - setattr(self, name, module) - if module is not None and self._component_manager is not None: - 
self._component_manager.add(name, module, self._collection) - continue - - current_module = getattr(self, name, None) - # skip if the component is already registered with the same object - if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") - continue - - # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") - - # warn if unregister - if current_module is not None and module is None: - logger.info( - f"ModularLoader.register_components: setting '{name}' to None " - f"(was {current_module.__class__.__name__})" - ) - # same type, new instance → debug - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: - logger.debug( - f"ModularLoader.register_components: replacing existing '{name}' " - f"(same type {type(current_module).__name__}, new instance)" - ) - - # save modular_model_index.json config - self.register_to_config(**register_dict) - # update component spec - self._component_specs[name] = new_component_spec - # finally set models - setattr(self, name, module) - if module is not None and self._component_manager is not None: - self._component_manager.add(name, module, self._collection) - - - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): - """ - Initialize the loader with a list of component specs and config specs. - """ - self._component_manager = component_manager - self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } - - # update component_specs and config_specs from modular_repo - if modular_repo is not None: - config_dict = self.load_config(modular_repo, **kwargs) - - for name, value in config_dict.items(): - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - self._component_specs[name] = component_spec - - elif name in self._config_specs: - self._config_specs[name].default = value - - register_components_dict = {} - for name, component_spec in self._component_specs.items(): - register_components_dict[name] = None - self.register_components(**register_components_dict) - - default_configs = {} - for name, config_spec in self._config_specs.items(): - default_configs[name] = config_spec.default - self.register_to_config(**default_configs) - - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. 
- """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module): - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - - @property - def components(self) -> Dict[str, Any]: - # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } - - def update(self, **kwargs): - """ - Update components and configs after instance creation. - - Args: - - """ - """ - Update components and configuration values after the loader has been instantiated. - - This method allows you to: - 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) - 2. 
Update configuration values (e.g., changing requires_safety_checker flag) - - Args: - **kwargs: Component objects or configuration values to update: - - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) - - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) - - Raises: - ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) - - Examples: - ```python - # Update multiple components at once - loader.update( - unet=new_unet_model, - text_encoder=new_text_encoder - ) - - # Update configuration values - loader.update( - requires_safety_checker=False, - guidance_rescale=0.7 - ) - - # Update both components and configs together - loader.update( - unet=new_unet_model, - requires_safety_checker=False - ) - ``` - """ - - # extract component_specs_updates & config_specs_updates from `specs` - passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} - passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} - - for name, component in passed_components.items(): - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") - - if len(kwargs) > 0: - logger.warning(f"Unexpected keyword arguments will be ignored: {kwargs.keys()}") - - - self.register_components(**passed_components) - - - config_to_register = {} - for name, new_value in passed_config_values.items(): - - # e.g. requires_aesthetics_score = False - self._config_specs[name].default = new_value - config_to_register[name] = new_value - self.register_to_config(**config_to_register) - - - # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, component_names: Optional[List[str]] = None, **kwargs): - """ - Load selected components from specs. - - Args: - component_names: List of component names to load - **kwargs: additional kwargs to be passed to `from_pretrained()`. Can be: - - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - - a different loading field (e.g. `repo`, `variant`, `revision`), which, if passed, overrides the corresponding field of the ComponentSpec
- """ - if component_names is None: - component_names = list(self._component_specs.keys()) - elif not isinstance(component_names, list): - component_names = [component_names] - - components_to_load = set([name for name in component_names if name in self._component_specs]) - unknown_component_names = set([name for name in component_names if name not in self._component_specs]) - if len(unknown_component_names) > 0: - logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - - components_to_register = {} - for name in components_to_load: - spec = self._component_specs[name] - component_load_kwargs = {} - for key, value in kwargs.items(): - if not isinstance(value, dict): - # if the value is a single value, apply it to all components - component_load_kwargs[key] = value - else: - if name in value: - # if it is a dict, check if the component name is in the dict - component_load_kwargs[key] = value[name] - elif "default" in value: - # check if the default is specified - component_load_kwargs[key] = value["default"] - try: - components_to_register[name] = spec.create(**component_load_kwargs) - except Exception as e: - logger.warning(f"Failed to create component '{name}': {e}") - - # Register all components at once - self.register_components(**components_to_register) - - # YiYi TODO: should support to method - def to(self, *args, **kwargs): - pass - - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - - component_names = list(self._component_specs.keys()) - config_names = list(self._config_specs.keys()) - self.register_to_config(_components_names=component_names, _configs_names=config_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - config.pop("_components_names", None) - config.pop("_configs_names", None) - self._internal_dict = FrozenDict(config) - - - @classmethod - @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): - - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - expected_component = set(config_dict.pop("_components_names")) - expected_config = set(config_dict.pop("_configs_names")) - - component_specs = [] - config_specs = [] - for name, value in config_dict.items(): - if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) - - elif name in expected_config: - config_specs.append(ConfigSpec(name=name, default=value)) - - for name in expected_component: - for spec in component_specs: - if spec.name == name: - break - else: - # append a empty component spec for these not in modular_model_index - component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) - return cls(component_specs + config_specs) - - - @staticmethod - def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: - """ - Convert a ComponentSpec into a JSON‐serializable dict for saving in - `modular_model_index.json`. 
- - This dict contains: - - "type_hint": Tuple[str, str] - The fully‐qualified module path and class name of the component. - - All loading fields defined by `component_spec.loading_fields()`, typically: - - "repo": Optional[str] - The model repository (e.g., "stabilityai/stable-diffusion-xl"). - - "subfolder": Optional[str] - A subfolder within the repo where this component lives. - - "variant": Optional[str] - An optional variant identifier for the model. - - "revision": Optional[str] - A specific git revision (commit hash, tag, or branch). - - ... any other loading fields defined on the spec. - - Args: - component_spec (ComponentSpec): - The spec object describing one pipeline component. - - Returns: - Dict[str, Any]: A mapping suitable for JSON serialization. - - Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec - >>> from diffusers.models.unet import UNet2DConditionModel - >>> spec = ComponentSpec( - ... name="unet", - ... type_hint=UNet2DConditionModel, - ... config=None, - ... repo="path/to/repo", - ... subfolder="subfolder", - ... variant=None, - ... revision=None, - ... default_creation_method="from_pretrained", - ... ) - >>> ModularLoader._component_spec_to_dict(spec) - { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": None, - "revision": None, - } - """ - if component_spec.type_hint is not None: - lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) - else: - lib_name = None - cls_name = None - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} - return { - "type_hint": (lib_name, cls_name), - **load_spec_dict, - } - - @staticmethod - def _dict_to_component_spec( - name: str, - spec_dict: Dict[str, Any], - ) -> ComponentSpec: - """ - Reconstruct a ComponentSpec from a dict. - """ - # make a shallow copy so we can pop() safely - spec_dict = spec_dict.copy() - # pull out and resolve the stored type_hint - lib_name, cls_name = spec_dict.pop("type_hint") - if lib_name is not None and cls_name is not None: - type_hint = simple_get_class_obj(lib_name, cls_name) - else: - type_hint = None - - # re‐assemble the ComponentSpec - return ComponentSpec( - name=name, - type_hint=type_hint, - **spec_dict, - ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py deleted file mode 100644 index 392d6dcd95..0000000000 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
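Purely as an illustration of the two serialization helpers removed just above (`_component_spec_to_dict` and `_dict_to_component_spec`), here is a minimal, hypothetical round-trip sketch. It assumes the pre-deletion import paths shown in this diff are still available, and the repo string is a placeholder, not something taken from the original code:

```python
# Illustrative sketch only: round-trips a ComponentSpec through the
# modular_model_index.json entry format handled by the removed helpers.
# Import paths follow the pre-deletion module locations in this diff.
import torch
from diffusers import UNet2DConditionModel
from diffusers.pipelines.modular_pipeline import ModularLoader
from diffusers.pipelines.modular_pipeline_utils import ComponentSpec

spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",  # placeholder repo
    subfolder="unet",
)

# Serialize: a (library, class_name) tuple plus the loading fields
# (repo, subfolder, variant, revision) defined on the spec.
entry = ModularLoader._component_spec_to_dict(spec)

# Deserialize: resolves type_hint back from the (library, class_name) tuple.
restored = ModularLoader._dict_to_component_spec("unet", entry)

# Equality is defined over name, load_id and default_creation_method,
# so the round-trip should compare equal.
assert restored == spec
```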
- -import re -import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal - -from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin - -if is_torch_available(): - import torch - - -# YiYi TODO: -# 1. validate the dataclass fields -# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() -@dataclass -class ComponentSpec: - """Specification for a pipeline component. - - A component can be created in two ways: - 1. From scratch using __init__ with a config dict - 2. using `from_pretrained` - - Attributes: - name: Name of the component - type_hint: Type of the component (e.g. UNet2DConditionModel) - description: Optional description of the component - config: Optional config dict for __init__ creation - repo: Optional repo path for from_pretrained creation - subfolder: Optional subfolder in repo - variant: Optional variant in repo - revision: Optional revision in repo - default_creation_method: Preferred creation method - "from_config" or "from_pretrained" - """ - name: Optional[str] = None - type_hint: Optional[Type] = None - description: Optional[str] = None - config: Optional[FrozenDict[str, Any]] = None - # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name - repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) - subfolder: Optional[str] = field(default=None, metadata={"loading": True}) - variant: Optional[str] = field(default=None, metadata={"loading": True}) - revision: Optional[str] = field(default=None, metadata={"loading": True}) - default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - - def __hash__(self): - """Make ComponentSpec hashable, using load_id as the hash value.""" - return hash((self.name, self.load_id, self.default_creation_method)) - - def __eq__(self, other): - """Compare ComponentSpec objects based on name and load_id.""" - if not isinstance(other, ComponentSpec): - return False - return (self.name == other.name and - self.load_id == other.load_id and - self.default_creation_method == other.default_creation_method) - - @classmethod - def from_component(cls, name: str, component: torch.nn.Module) -> Any: - """Create a ComponentSpec from a Component created by `create` method.""" - - if not hasattr(component, "_diffusers_load_id"): - raise ValueError("Component is not created by `create` method") - - type_hint = component.__class__ - - if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): - config = component.config - else: - config = None - - load_spec = cls.decode_load_id(component._diffusers_load_id) - - return cls(name=name, type_hint=type_hint, config=config, **load_spec) - - @classmethod - def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: - """Create a ComponentSpec from a load_id string.""" - if load_id == "null": - raise ValueError("Cannot create ComponentSpec from null load_id") - - # Decode the load_id into a dictionary of loading fields - load_fields = cls.decode_load_id(load_id) - - # Create a new ComponentSpec instance with the decoded fields - return cls(name=name, **load_fields) - - @classmethod - def loading_fields(cls) -> List[str]: - """ - Return the names of all loading‐related fields - (i.e. those whose field.metadata["loading"] is True). 
- """ - return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - - @property - def load_id(self) -> str: - """ - Unique identifier for this spec's pretrained load, - composed of repo|subfolder|variant|revision (no empty segments). - """ - parts = [getattr(self, k) for k in self.loading_fields()] - parts = ["null" if p is None else p for p in parts] - return "|".join(p for p in parts if p) - - @classmethod - def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: - """ - Decode a load_id string back into a dictionary of loading fields and values. - - Args: - load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" - where None values are represented as "null" - - Returns: - Dict mapping loading field names to their values. e.g. - { - "repo": "path/to/repo", - "subfolder": "subfolder", - "variant": "variant", - "revision": "revision" - } - If a segment value is "null", it's replaced with None. - Returns None if load_id is "null" (indicating component not loaded from pretrained). - """ - - # Get all loading fields in order - loading_fields = cls.loading_fields() - result = {f: None for f in loading_fields} - - if load_id == "null": - return result - - # Split the load_id - parts = load_id.split("|") - - # Map parts to loading fields by position - for i, part in enumerate(parts): - if i < len(loading_fields): - # Convert "null" string back to None - result[loading_fields[i]] = None if part == "null" else part - - return result - - # YiYi TODO: add validator - def create(self, **kwargs) -> Any: - """Create the component using the preferred creation method.""" - - # from_pretrained creation - if self.default_creation_method == "from_pretrained": - return self.create_from_pretrained(**kwargs) - elif self.default_creation_method == "from_config": - # from_config creation - return self.create_from_config(**kwargs) - else: - raise ValueError(f"Invalid creation method: {self.default_creation_method}") - - def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: - """Create component using from_config with config.""" - - if self.type_hint is None or not isinstance(self.type_hint, type): - raise ValueError( - f"`type_hint` is required when using from_config creation method." - ) - - config = config or self.config or {} - - if issubclass(self.type_hint, ConfigMixin): - component = self.type_hint.from_config(config, **kwargs) - else: - signature_params = inspect.signature(self.type_hint.__init__).parameters - init_kwargs = {} - for k, v in config.items(): - if k in signature_params: - init_kwargs[k] = v - for k, v in kwargs.items(): - if k in signature_params: - init_kwargs[k] = v - component = self.type_hint(**init_kwargs) - - component._diffusers_load_id = "null" - if hasattr(component, "config"): - self.config = component.config - - return component - - # YiYi TODO: add guard for type of model, if it is supported by from_pretrained - def create_from_pretrained(self, **kwargs) -> Any: - """Create component using from_pretrained.""" - - passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} - load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} - # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path - repo = load_kwargs.pop("repo", None) - if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - - if self.type_hint is None: - try: - from diffusers import AutoModel - component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") - self.type_hint = component.__class__ - else: - try: - component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") - - if repo != self.repo: - self.repo = repo - for k, v in passed_loading_kwargs.items(): - if v is not None: - setattr(self, k, v) - component._diffusers_load_id = self.load_id - - return component - - - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None -@dataclass -class InputParam: - """Specification for an input parameter.""" - name: str = None - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - - -@dataclass -class OutputParam: - """Specification for an output parameter.""" - name: str - type_hint: Any = None - description: str = "" - kwargs_type: str = None - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): - """ - Formats intermediate inputs and outputs of a block into a string representation. 
- - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - if inp.name is None and inp.kwargs_type is not None: - inp_name = "*_" + inp.kwargs_type - else: - inp_name = inp.name - input_parts.append(inp_name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params, header="Args", indent_level=4, max_line_length=115): - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text, indent, max_length): - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if hasattr(param, "required"): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add 
description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - - -def format_input_params(input_params, indent_level=4, max_line_length=115): - """Format a list of InputParam objects into a readable string representation. - - Args: - input_params: List of InputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all input parameters - """ - return format_params(input_params, "Inputs", indent_level, max_line_length) - - -def format_output_params(output_params, indent_level=4, max_line_length=115): - """Format a list of OutputParam objects into a readable string representation. - - Args: - output_params: List of OutputParam objects to format - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all output parameters - """ - return format_params(output_params, "Outputs", indent_level, max_line_length) - - -def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ComponentSpec objects into a readable string representation. - - Args: - components: List of ComponentSpec objects to format - indent_level: Number of spaces to indent each component line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between components (default: True) - - Returns: - A formatted string representing all components - """ - if not components: - return "" - - base_indent = " " * indent_level - component_indent = " " * (indent_level + 4) - formatted_components = [] - - # Add the header - formatted_components.append(f"{base_indent}Components:") - if add_empty_lines: - formatted_components.append("") - - # Add each component with optional empty lines between them - for i, component in enumerate(components): - # Get type name, handling special cases - type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - - component_desc = f"{component_indent}{component.name} (`{type_name}`)" - if component.description: - component_desc += f": {component.description}" - - # Get the loading fields dynamically - loading_field_values = [] - for field_name in component.loading_fields(): - field_value = getattr(component, field_name) - if field_value is not None: - loading_field_values.append(f"{field_name}={field_value}") - - # Add loading field information if available - if loading_field_values: - component_desc += f" [{', '.join(loading_field_values)}]" - - formatted_components.append(component_desc) - - # Add an empty line after each component except the last one - if add_empty_lines and i < len(components) - 1: - formatted_components.append("") - - return "\n".join(formatted_components) - - -def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): - """Format a list of ConfigSpec objects into a readable string representation. 
- - Args: - configs: List of ConfigSpec objects to format - indent_level: Number of spaces to indent each config line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between configs (default: True) - - Returns: - A formatted string representing all configs - """ - if not configs: - return "" - - base_indent = " " * indent_level - config_indent = " " * (indent_level + 4) - formatted_configs = [] - - # Add the header - formatted_configs.append(f"{base_indent}Configs:") - if add_empty_lines: - formatted_configs.append("") - - # Add each config with optional empty lines between them - for i, config in enumerate(configs): - config_desc = f"{config_indent}{config.name} (default: {config.default})" - if config.description: - config_desc += f": {config.description}" - formatted_configs.append(config_desc) - - # Add an empty line after each config except the last one - if add_empty_lines and i < len(configs) - 1: - formatted_configs.append("") - - return "\n".join(formatted_configs) - - -def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Args: - inputs: List of input parameters - intermediates_inputs: List of intermediate input parameters - outputs: List of output parameters - description (str, *optional*): Description of the block - class_name (str, *optional*): Name of the class to include in the documentation - expected_components (List[ComponentSpec], *optional*): List of expected components - expected_configs (List[ConfigSpec], *optional*): List of expected configurations - - Returns: - str: A formatted string containing information about components, configs, call parameters, - intermediate inputs/outputs, and final outputs. - """ - output = "" - - # Add class name if provided - if class_name: - output += f"class {class_name}\n\n" - - # Add description - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - # Add components section if provided - if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) - output += components_str + "\n\n" - - # Add configs section if provided - if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) - output += configs_str + "\n\n" - - # Add inputs section - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - # Add outputs section - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index acb3953450..0000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3032 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor, unwrap_module -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularLoader, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, - ComponentSpec, - ConfigSpec, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance -from ...configuration_utils import FrozenDict - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? -# YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader -## (2) acts like a container that holds components and configs -## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents -## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) -## (5) how to use together with Components_manager? 
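As a rough illustration of the usage pattern the notes above gesture at (a model-specific loader built on the `ModularLoader` API removed earlier in this diff), here is a hedged sketch. The repo id, component names, dtypes, and the config key are placeholders rather than anything defined in the original code:

```python
# Hypothetical usage sketch of a model-specific ModularLoader subclass.
# Identifiers follow the removed APIs shown earlier in this diff; the repo id
# and dtype choices are illustrative placeholders.
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    StableDiffusionXLModularLoader,
)

# Build the loader from a repo that ships a modular_model_index.json.
loader = StableDiffusionXLModularLoader.from_pretrained("some-org/sdxl-modular")

# Load a subset of components; per-component kwargs use the dict form
# documented on ModularLoader.load() above.
loader.load(
    component_names=["unet", "vae", "text_encoder"],
    torch_dtype={"unet": torch.float16, "default": torch.float32},
)

# Swap components or flip config values after creation
# (assuming this config key was registered on the loader).
loader.update(force_zeros_for_empty_prompt=False)
```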
-class StableDiffusionXLModularLoader( - ModularLoader, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -class StableDiffusionXLIPAdapterStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - PipelineImageInput, - required=True, - description="The image(s) to be used as ip adapter" - ) - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components - @staticmethod - def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(components.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = components.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = components.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = components.image_encoder(image).image_embeds - image_embeds = 
image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds - ): - image_embeds = [] - if prepare_unconditional_embeds: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - components, single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if prepare_unconditional_embeds: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if prepare_unconditional_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if prepare_unconditional_embeds: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( - components, - ip_adapter_image=block_state.ip_adapter_image, - ip_adapter_image_embeds=None, - device=block_state.device, - num_images_per_prompt=1, - prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, - ) - if block_state.prepare_unconditional_embeds: - block_state.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(block_state.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - block_state.negative_ip_adapter_embeds.append(negative_image_embeds) - block_state.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return( - "Text Encoder step that 
generate text_embeddings to guide the image generation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("text_encoder", CLIPTextModel), - ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), - ComponentSpec("tokenizer", CLIPTokenizer), - ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("force_zeros_for_empty_prompt", True)] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("prompt"), - InputParam("prompt_2"), - InputParam("negative_prompt"), - InputParam("negative_prompt_2"), - InputParam("cross_attention_kwargs"), - InputParam("clip_skip"), - ] - - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - ] - - @staticmethod - def check_inputs(block_state): - - if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") - - @staticmethod - def encode_prompt( - components, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prepare_unconditional_embeds: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prepare_unconditional_embeds (`bool`): - whether to use prepare unconditional embeddings or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - device = device or components._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): - components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if components.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) - else: - scale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) - else: - scale_lora_layers(components.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] - text_encoders = ( - [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - prompt = components.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids 
= text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif prepare_unconditional_embeds and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(components, TextualInversionLoaderMixin): - negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if components.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if components.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if prepare_unconditional_embeds: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if components.text_encoder is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) - - if components.text_encoder_2 is not None: - if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - block_state = self.get_block_state(state) - self.check_inputs(block_state) - - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 - block_state.device = components._execution_device - - # Encode input prompt - block_state.text_encoder_lora_scale = ( - block_state.cross_attention_kwargs.get("scale", None) if 
block_state.cross_attention_kwargs is not None else None - ) - ( - block_state.prompt_embeds, - block_state.negative_prompt_embeds, - block_state.pooled_prompt_embeds, - block_state.negative_pooled_prompt_embeds, - ) = self.encode_prompt( - components, - block_state.prompt, - block_state.prompt_2, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - block_state.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=block_state.text_encoder_lora_scale, - clip_skip=block_state.clip_skip, - ) - # Add outputs - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encode the input image into a latent representation" - ) - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("generator"), - InputParam("height"), - InputParam("width"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = 
latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - - block_state.batch_size = block_state.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." - ) - - - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ComponentSpec( - "mask_processor", - VaeImageProcessor, - config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), - default_creation_method="from_config"), - ] - - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - # Modified from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Copied from - def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. 
Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - if block_state.padding_mask_crop is not None: - block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop) - block_state.resize_mode = "fill" - else: - block_state.crops_coords = None - block_state.resize_mode = "default" - - block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode) - block_state.image = block_state.image.to(dtype=torch.float32) - - block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords) - block_state.masked_image = block_state.image * (block_state.mask < 0.5) - - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image, - block_state.batch_size, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - - return components, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. 
Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, components, block_state): - - if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: - if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {block_state.negative_prompt_embeds.shape}." - ) - - if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
- ) - - if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {block_state.negative_ip_adapter_embeds[i].shape}." - ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype - - _, seq_len, _ = block_state.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - if block_state.negative_prompt_embeds is not None: - _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) - - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.negative_pooled_prompt_embeds is not None: - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - - if block_state.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): - block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - if block_state.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): - block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the 
initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - InputParam("strength", default=0.3), - InputParam("denoising_start"), - # YiYi TODO: do we need num_images_per_prompt here? - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components - def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start * components.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (denoising_start * components.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if components.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(components.scheduler.timesteps) - num_inference_steps - timesteps = components.scheduler.timesteps[t_start:] - if hasattr(components.scheduler, "set_begin_index"): - components.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( - components, - block_state.num_inference_steps, - block_state.strength, - block_state.device, - denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, - ) - block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("denoising_end"), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.device = components._execution_device - - block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas - ) - - if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1: - 
block_state.discrete_timestep_cutoff = int( - round( - components.scheduler.config.num_train_timesteps - - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) - ) - ) - block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) - block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - InputParam( - "strength", - default=0.9999, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
- ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components - # YiYi TODO: update the _encode_vae_image so that we can use #Coped from - @staticmethod - def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): - - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - - dtype = image.dtype - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list): - image_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - image_latents = image_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) - latents_std = latents_std.to(device=image_latents.device, dtype=dtype) - image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - image_latents = components.vae.config.scaling_factor * image_latents - - return image_latents - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument - def prepare_latents_inpaint( - self, - components, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - image=None, - timestep=None, - is_strength_max=True, - add_noise=True, - return_noise=False, - return_image_latents=False, - ): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if (image is None or timestep is None) and not is_strength_max: - raise ValueError( - "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." 
- "However, either the image or the noise timestep has not been provided." - ) - - if image.shape[1] == 4: - image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(components, image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - - if latents is None and add_noise: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents - elif add_noise: - noise = latents.to(device) - latents = noise * components.scheduler.init_noise_sigma - else: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = image_latents.to(device) - - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) - - return outputs - - # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents - # do not accept do_classifier_free_guidance - def prepare_mask_latents( - self, components, mask, masked_image, batch_size, height, width, dtype, device, generator - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - - if masked_image is not None and masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - masked_image_latents = None - - if masked_image is not None: - if masked_image_latents is None: - masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) - - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - return mask, masked_image_latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - - block_state.is_strength_max = block_state.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(components,"unet") and components.unet is not None: - if components.unet.config.in_channels == 4: - block_state.masked_image_latents = None - - block_state.add_noise = True if block_state.denoising_start is None else False - - block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor - block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor - - block_state.latents, block_state.noise = self.prepare_latents_inpaint( - components, - block_state.batch_size * block_state.num_images_per_prompt, - components.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - image=block_state.image_latents, - timestep=block_state.latent_timestep, - is_strength_max=block_state.is_strength_max, - add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. Prepare mask latent variables - block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( - components, - block_state.mask, - block_state.masked_image_latents, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - InputParam("denoising_start"), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components - # YiYi TODO: refactor using _encode_vae_image - @staticmethod - def prepare_latents_img2img( - components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: - latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: - latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if components.vae.config.force_upcast: - image = image.float() - components.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - elif isinstance(generator, list): - if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: - image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) - elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " - ) - - init_latents = [ - retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = retrieve_latents(components.vae.encode(image), generator=generator) - - if components.vae.config.force_upcast: - components.vae.to(dtype) - - init_latents = init_latents.to(dtype) - if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=device, dtype=dtype) - latents_std = latents_std.to(device=device, dtype=dtype) - init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std - else: - init_latents = components.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - if add_noise: - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents - init_latents = components.scheduler.add_noise(init_latents, noise, timestep) - - latents = init_latents - - return latents - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.device = components._execution_device - block_state.add_noise = True if block_state.denoising_start is None else False - if block_state.latents is None: - block_state.latents = self.prepare_latents_img2img( - components, - block_state.image_latents, - block_state.latent_timestep, - block_state.batch_size, - block_state.num_images_per_prompt, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.add_noise, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("latents"), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - - @staticmethod - def check_inputs(components, block_state): - if ( - block_state.height is not None - and block_state.height % components.vae_scale_factor != 0 - or block_state.width is not None - and block_state.width % components.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components - @staticmethod - def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // components.vae_scale_factor, - int(width) // components.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * components.scheduler.init_noise_sigma - return latents - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.dtype is None: - block_state.dtype = components.vae.dtype - - block_state.device = components._execution_device - - self.check_inputs(components, block_state) - - block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor - block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor - block_state.num_channels_latents = components.num_channels_latents - block_state.latents = self.prepare_latents( - components, - block_state.batch_size * block_state.num_images_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, - ) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", False),] - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - InputParam("aesthetic_score", default=6.0), - InputParam("negative_aesthetic_score", default=2.0), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids_img2img( - components, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if components.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.vae_scale_factor = components.vae_scale_factor - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * block_state.vae_scale_factor - block_state.width = block_state.width * block_state.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - if block_state.negative_original_size is None: - block_state.negative_original_size = block_state.original_size - if block_state.negative_target_size is None: - block_state.negative_target_size = block_state.target_size - - block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.aesthetic_score, - block_state.negative_aesthetic_score, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - dtype=block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and 
components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("original_size"), - InputParam("target_size"), - InputParam("negative_original_size"), - InputParam("negative_target_size"), - InputParam("crops_coords_top_left", default=(0, 0)), - InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components - @staticmethod - def _get_add_time_ids( - components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - @staticmethod - def get_guidance_scale_embedding( - w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - block_state.device = components._execution_device - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) - - block_state.add_time_ids = self._get_add_time_ids( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - if block_state.negative_original_size is not None and block_state.negative_target_size is not None: - block_state.negative_add_time_ids = self._get_add_time_ids( - components, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - block_state.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=block_state.text_encoder_projection_dim, - ) - else: - block_state.negative_add_time_ids = block_state.add_time_ids - - block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) - - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. 
Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) - - self.add_block_state(state, block_state) - return components, state - -class StableDiffusionXLControlNetInputStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepare inputs for controlnet" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), - OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), - OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. 
add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - # (1) prepare controlnet inputs - block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - block_state.control_guidance_start, block_state.control_guidance_end = ( - mult * [block_state.control_guidance_start], - mult * [block_state.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( - components, - image=block_state.control_image, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in block_state.control_image: - control_image = 
self.prepare_control_image( - components, - image=control_image_, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - - control_images.append(control_image) - - block_state.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - keeps = [ - 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) - for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) - ] - block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - - - self.add_block_state(state, block_state) - - return components, state - -class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "step that prepares inputs for the ControlNetUnion model" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("control_image", required=True), - InputParam("control_mode", required=True), - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("guess_mode", default=False), - InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of model tensor inputs. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), - OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), - OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), - OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), - OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), - OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), - OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), - OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), - ] - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - - block_state = self.get_block_state(state) - - controlnet = unwrap_module(components.controlnet) - - device = components._execution_device - dtype = block_state.dtype or components.controlnet.dtype - - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - - - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - - # guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # control_image - if not isinstance(block_state.control_image, list): - block_state.control_image = [block_state.control_image] - # control_mode - if not isinstance(block_state.control_mode, list): - block_state.control_mode = 
[block_state.control_mode] - - if len(block_state.control_image) != len(block_state.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - # control_type - block_state.num_control_type = controlnet.config.num_control_type - block_state.control_type = [0 for _ in range(block_state.num_control_type)] - for control_idx in block_state.control_mode: - block_state.control_type[control_idx] = 1 - block_state.control_type = torch.Tensor(block_state.control_type) - - block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) - repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] - block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - - # prepare control_image - for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( - components, - image=block_state.control_image[idx], - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=device, - dtype=dtype, - crops_coords=block_state.crops_coords, - ) - block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] - - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - block_state.controlnet_keep.append( - 1.0 - - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end) - ) - block_state.control_type_idx = block_state.control_mode - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks): - - block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] - block_names = ["controlnet_union", "controlnet"] - block_trigger_inputs = ["control_mode", "control_image"] - - -class StableDiffusionXLDecodeLatentsStep(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), - default_creation_method="from_config"), - ] - - @property - def description(self) -> str: - return "Step that decodes the denoised latents into images" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components - @staticmethod - def upcast_vae(components): - dtype = components.vae.dtype - components.vae.to(dtype=torch.float32) - 
use_torch_2_0_or_xformers = isinstance( - components.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - components.vae.post_quant_conv.to(dtype) - components.vae.decoder.conv_in.to(dtype) - components.vae.decoder.mid_block.to(dtype) - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - - if block_state.needs_upcasting: - self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - block_state.has_latents_mean = ( - hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None - ) - block_state.has_latents_std = ( - hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None - ) - if block_state.has_latents_mean and block_state.has_latents_std: - block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) - ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean - else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor - - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if block_state.needs_upcasting: - components.vae.to(dtype=torch.float16) - else: - block_state.images = block_state.latents - - # apply watermark if available - if hasattr(components, "watermark") and components.watermark is not None: - block_state.images = components.watermark.apply_watermark(block_state.images) - - block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type) - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("image", required=True), - InputParam("mask_image", required=True), - InputParam("padding_mask_crop"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam("images", required=True, 
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"), - InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.") - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")] - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: - block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images] - - self.add_block_state(state, block_state) - - return components, state - - -class StableDiffusionXLOutputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "final step to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] - - @property - def intermediates_inputs(self) -> List[str]: - return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - - @torch.no_grad() - def __call__(self, components, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - if not block_state.return_dict: - block_state.images = (block_state.images,) - else: - block_state.images = StableDiffusionXLPipelineOutput(images=block_state.images) - self.add_block_state(state, block_state) - return components, state - - -# Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return "Vae encoder step that encode the image inputs into their latent representations.\n" + \ - "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \ - " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \ - " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." 
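The auto blocks above select among their sub-blocks via trigger inputs: the first sub-block whose trigger input is present in the pipeline state runs, a `None` trigger marks the default branch, and the step is skipped when nothing matches. A minimal, self-contained sketch of that dispatch rule (illustrative only, not the library implementation):

# Illustrative sketch (not library code) of the trigger-input dispatch used by
# AutoPipelineBlocks subclasses such as StableDiffusionXLAutoVaeEncoderStep.
def select_block(block_names, block_trigger_inputs, provided_inputs):
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or provided_inputs.get(trigger) is not None:
            return name  # first matching block wins; a None trigger is the default branch
    return None  # no trigger matched -> the auto step is skipped


# mask_image present -> "inpaint"; only image present -> "img2img"
print(select_block(["inpaint", "img2img"], ["mask_image", "image"], {"image": "<pil image>"}))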
- - -# Before denoise -class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput] - block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"] - - @property - def description(self): - return "Before denoise step that prepares the inputs for the denoise step for the inpainting task.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \ - " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \ - " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \ - " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \ - " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input" - - -class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep] - block_names = ["inpaint", "img2img", "text2img"] - block_trigger_inputs = ["mask", "image_latents", None] - - @property - def description(self): - return "Before 
denoise step that prepares the inputs for the denoise step.\n" + \ - "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet and controlnet_union.\n" + \ - " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \ - " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \ - " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided.\n" + \ - " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \ - " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." - -# # Denoise -from .pipeline_stable_diffusion_xl_modular_denoise_loop import StableDiffusionXLDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLAutoDenoiseStep -# class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): -# block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep] -# block_names = ["controlnet_union", "controlnet", "unet"] -# block_trigger_inputs = ["control_mode", "control_image", None] - -# @property -# def description(self): -# return "Denoise step that denoise the latents.\n" + \ -# "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \ -# " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \ -# " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \ -# " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided." - -# After denoise -class StableDiffusionXLDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep] - block_names = ["decode", "output"] - - @property - def description(self): - return """Decode step that decodes the denoised latents into image outputs. -This is a sequential pipeline of blocks: - - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images - - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple.""" - - -class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep] - block_names = ["decode", "mask_overlay", "output"] - - @property - def description(self): - return "Inpaint decode step that decodes the denoised latents into image outputs.\n" + \ - "This is a sequential pipeline of blocks:\n" + \ - " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \ - " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \ - " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple." 
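The decode composites above follow the same contract as every other `SequentialPipelineBlocks` in this file: each sub-block is called with `(components, state)` and returns the pair for the next sub-block. A minimal sketch of that control flow, with illustrative names rather than the library implementation:

# Minimal sketch (illustrative only) of how a SequentialPipelineBlocks-style
# composite threads one shared pipeline state through its sub-blocks in order.
class TinySequentialBlocks:
    def __init__(self, blocks):
        # blocks: ordered mapping of name -> callable(components, state) -> (components, state)
        self.blocks = blocks

    def __call__(self, components, state):
        for _name, block in self.blocks.items():
            components, state = block(components, state)
        return components, state

# e.g. TinySequentialBlocks({"decode": decode_block, "output": output_block})(components, state)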
- - -class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] - block_names = ["inpaint", "non-inpaint"] - block_trigger_inputs = ["padding_mask_crop", None] - - @property - def description(self): - return "Decode step that decodes the denoised latents into image outputs.\n" + \ - "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \ - " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \ - " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." - - -class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin): - block_classes = [StableDiffusionXLIPAdapterStep] - block_names = ["ip_adapter"] - block_trigger_inputs = ["ip_adapter_image"] - - @property - def description(self): - return "Run IP Adapter step if `ip_adapter_image` is provided." - - -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`; optionally you can provide `padding_mask_crop`\n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# TODO(yiyi, aryan): We need another step before the text encoder to set the `num_inference_steps` attribute for the guider so that -# things like when to apply guidance and how many conditions need to be prepared can be determined. Currently, this is done by -# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of the -# guider's configuration. 
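Per the description above, the active workflow is determined purely by which inputs the caller provides. A small routing sketch derived from that description (illustrative only; the real selection happens inside the auto blocks at runtime):

# Sketch of the input-to-workflow routing described by StableDiffusionXLAutoPipeline
# (illustrative only; not the actual dispatch logic of the auto blocks).
def active_workflow(**user_inputs):
    extras = []
    if "control_image" in user_inputs:
        extras.append("controlnet_union" if "control_mode" in user_inputs else "controlnet")
    if "ip_adapter_image" in user_inputs:
        extras.append("ip_adapter")
    if "mask_image" in user_inputs and "image" in user_inputs:
        return "inpaint", extras
    if "image" in user_inputs or "image_latents" in user_inputs:
        return "img2img", extras
    return "text2img", extras  # `prompt` alone is enough for text-to-image


print(active_workflow(prompt="a cat", control_image="<image>"))  # ('text2img', ['controlnet'])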
- - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { - "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - - -# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), - "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), - "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), - "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), - "cross_attention_kwargs": 
InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), - "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), - "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), - "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), - "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), - "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), - "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), - "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), - "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), - "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), - "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), - "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), - "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), - "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), - "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), - "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), - "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), - "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "eta": InputParam("eta", type_hint=float, 
default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), - "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), - "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), - "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), - "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), - "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), - "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), - "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), - "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") -} - - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), - "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), - "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), - "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), - "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, 
description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") -} - - -SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { - "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), - "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), - "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), - "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), - "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), - "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), - "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), - "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), - "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), - "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), - "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), - "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), - "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), - "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), - "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), - "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), - "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") -} - - -SDXL_OUTPUTS_SCHEMA = { - "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} diff --git 
a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py deleted file mode 100644 index 63d0784a57..0000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular_denoise_loop.py +++ /dev/null @@ -1,1363 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from tqdm.auto import tqdm - -from ...configuration_utils import FrozenDict -from ...models import ControlNetModel, UNet2DConditionModel -from ...schedulers import EulerDiscreteScheduler -from ...utils import logging -from ...utils.torch_utils import unwrap_module -from ..modular_pipeline import ( - PipelineBlock, - PipelineState, - AutoPipelineBlocks, - LoopSequentialPipelineBlocks, - InputParam, - OutputParam, - BlockState, - ComponentSpec, -) -from ...guiders import ClassifierFreeGuidance -from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader -from dataclasses import asdict - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -# YiYi experimenting composible denoise loop -# loop step (1): prepare latent input for denoiser -class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." 
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - - - return components, block_state - -# loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that prepare the latent input for the denoiser" - - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")] - - @staticmethod - def check_inputs(components, block_state): - - num_channels_unet = components.num_channels_unet - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if block_state.mask is None or block_state.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = block_state.latents.shape[1] - num_channels_mask = block_state.mask.shape[1] - num_channels_masked_image = block_state.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" - f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `components.unet` or your `mask_image` or `image` input." 
- ) - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - - return components, block_state - -# loop step (2): denoise the latents with guidance -class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return ( - "Step within the denoising loop that denoise the latents with guidance" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - cond_kwargs = guider_state_batch.as_dict() - cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} - prompt_embeds = cond_kwargs.pop("prompt_embeds") - - # Predict the noise residual - # store the noise_pred in guider_state_batch so that we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=cond_kwargs, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("controlnet", ControlNetModel), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("cross_attention_kwargs"), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "controlnet_cond", - required=True, - type_hint=torch.Tensor, - description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "guess_mode", - required=True, - type_hint=bool, - description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "controlnet_keep", - required=True, - type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." - ), - InputParam( - "scaled_latents", - required=True, - type_hint=torch.Tensor, - description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop." 
- ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - kwargs_type="guider_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds, " - "add_time_ids/negative_add_time_ids, " - "pooled_prompt_embeds/negative_pooled_prompt_embeds, " - "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." - "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ), - InputParam( - kwargs_type="controlnet_kwargs", - description=( - "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" - "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ) - ) - ] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) - - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_input_fields ={ - "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), - "time_ids": ("add_time_ids", "negative_add_time_ids"), - "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), - "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), - } - - - # cond_scale for the timestep (controlnet input) - if isinstance(block_state.controlnet_keep[i], list): - block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] - else: - controlnet_cond_scale = block_state.conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] - - # default controlnet output/unet input for guess mode + conditional path - block_state.down_block_res_samples_zeros = None - block_state.mid_block_res_sample_zeros = None - - # guided denoiser step - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - - # Prepare mini‐batches according to guidance method and `guider_input_fields` - # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. - # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond - # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds - # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds - guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) - - # run the denoiser for each guidance batch - for guider_state_batch in guider_state: - components.guider.prepare_models(components.unet) - - # Prepare additional conditionings - added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: - added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds - - # Prepare controlnet additional conditionings - controlnet_added_cond_kwargs = { - "text_embeds": guider_state_batch.text_embeds, - "time_ids": guider_state_batch.time_ids, - } - # run controlnet for the guidance batch - if block_state.guess_mode and not components.guider.is_conditional: - # guider always run uncond batch first, so these tensors should be set already - down_block_res_samples = block_state.down_block_res_samples_zeros - mid_block_res_sample = block_state.mid_block_res_sample_zeros - else: - down_block_res_samples, mid_block_res_sample = components.controlnet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - controlnet_cond=block_state.controlnet_cond, - conditioning_scale=block_state.cond_scale, - guess_mode=block_state.guess_mode, - added_cond_kwargs=controlnet_added_cond_kwargs, - return_dict=False, - **extra_controlnet_kwargs, - ) - - # assign it to block_state so it will be available for the uncond guidance batch - if block_state.down_block_res_samples_zeros is None: - block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] - if block_state.mid_block_res_sample_zeros is None: - block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) - - # Predict the noise - # store the noise_pred in guider_state_batch so we can apply guidance across all batches - guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, - t, - encoder_hidden_states=guider_state_batch.prompt_embeds, - timestep_cond=block_state.timestep_cond, - cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - components.guider.cleanup_models(components.unet) - - # Perform guidance - block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) - - return components, block_state - -# loop step (3): scheduler step to update latents -class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ] - - @property - def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. 
Performs the scheduler step to update the latents using the predicted noise" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - #YiYi TODO: move this out of here - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - return components, block_state - -# loop step (3): scheduler step to update latents (with inpainting) -class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock): - - model_name = "stable-diffusion-xl" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def description(self) -> str: - return "step within the denoising loop that performs the scheduler step to update the latents and, for inpainting, blends the unmasked region of the original image latents back into the result" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), - InputParam("eta", default=0.0), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. 
Can be generated in vae_encode or prepare_latent step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - @staticmethod - def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - - accepted_kwargs = set(inspect.signature(func).parameters.keys()) - extra_kwargs = {} - for key, value in kwargs.items(): - if key in accepted_kwargs and key not in exclude_kwargs: - extra_kwargs[key] = value - - return extra_kwargs - - def check_inputs(self, components, block_state): - if components.num_channels_unet == 4: - if block_state.image_latents is None: - raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") - if block_state.mask is None: - raise ValueError(f"mask is required for this step {self.__class__.__name__}") - if block_state.noise is None: - raise ValueError(f"noise is required for this step {self.__class__.__name__}") - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): - - self.check_inputs(components, block_state) - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) - - - # Perform scheduler step using the predicted output - block_state.latents_dtype = block_state.latents.dtype - block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] - - if block_state.latents.dtype != block_state.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - block_state.latents = block_state.latents.to(block_state.latents_dtype) - - # adjust latent for inpainting - if components.num_channels_unet == 4: - block_state.init_latents_proper = block_state.image_latents - if i < len(block_state.timesteps) - 1: - block_state.noise_timestep = block_state.timesteps[i + 1] - block_state.init_latents_proper = components.scheduler.add_noise( - block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) - ) - - block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - - - - return components, block_state - - -# the loop wrapper that iterates over the timesteps -class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): - - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def loop_expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), - ] - - @property - def loop_intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
- ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - ] - - - @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False - if block_state.disable_guidance: - components.guider.disable() - else: - components.guider.enable() - - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: - for i, t in enumerate(block_state.timesteps): - components, block_state = self.loop_step(components, block_state, i=i, t=t) - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): - progress_bar.update() - - self.add_block_state(state, block_state) - - return components, state - - -# composing the denoising loops -class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond -class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# mask -class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - -# control_cond + mask -class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - - -# all tasks without controlnet -class StableDiffusionXLDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - -# all tasks with controlnet -class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop] - block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] - block_trigger_inputs = ["mask", None] - -# all tasks with or without controlnet -class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] - - - - - - - -# YiYi Notes: alternatively, you can just write the denoise loop as a single pipeline block, which is easier but not composable -# class StableDiffusionXLDenoiseStep(PipelineBlock): - -# model_name = 
"stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ] - -# @property -# def description(self) -> str: -# return ( -# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" -# ) - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("num_images_per_prompt", default=1), -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
-# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) - -# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components -# @staticmethod -# def prepare_extra_step_kwargs(components, generator, eta): -# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature -# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
-# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 -# # and should be between [0, 1] - -# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# extra_step_kwargs = {} -# if accepts_eta: -# extra_step_kwargs["eta"] = eta - -# # check if the scheduler accepts generator -# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) -# if accepts_generator: -# extra_step_kwargs["generator"] = generator -# return extra_step_kwargs - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) - -# block_state.num_channels_unet = components.unet.config.in_channels -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() - -# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_data = components.guider.prepare_inputs(block_state) - -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) - -# # Prepare for inpainting -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - -# for batch in guider_data: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# batch.added_cond_kwargs = { -# "text_embeds": batch.pooled_prompt_embeds, -# "time_ids": batch.add_time_ids, -# } -# if batch.ip_adapter_embeds is not None: -# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - -# # Predict the noise residual -# batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=batch.added_cond_kwargs, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if 
torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state - - - -# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - -# model_name = "stable-diffusion-xl" - -# @property -# def expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec( -# "guider", -# ClassifierFreeGuidance, -# config=FrozenDict({"guidance_scale": 7.5}), -# default_creation_method="from_config"), -# ComponentSpec("scheduler", EulerDiscreteScheduler), -# ComponentSpec("unet", UNet2DConditionModel), -# ComponentSpec("controlnet", ControlNetModel), -# ] - -# @property -# def description(self) -> str: -# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" - -# @property -# def inputs(self) -> List[Tuple[str, Any]]: -# return [ -# InputParam("num_images_per_prompt", default=1), -# InputParam("cross_attention_kwargs"), -# InputParam("generator"), -# InputParam("eta", default=0.0), -# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) -# ] - -# @property -# def intermediates_inputs(self) -> List[str]: -# return [ -# InputParam( -# "controlnet_cond", -# required=True, -# type_hint=torch.Tensor, -# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_start", -# required=True, -# type_hint=float, -# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "control_guidance_end", -# required=True, -# type_hint=float, -# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "conditioning_scale", -# type_hint=float, -# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "guess_mode", -# required=True, -# type_hint=bool, -# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." -# ), -# InputParam( -# "controlnet_keep", -# required=True, -# type_hint=List[float], -# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
-# ), -# InputParam( -# "latents", -# required=True, -# type_hint=torch.Tensor, -# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." -# ), -# InputParam( -# "batch_size", -# required=True, -# type_hint=int, -# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." -# ), -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam( -# "prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "add_time_ids", -# required=True, -# type_hint=torch.Tensor, -# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "negative_add_time_ids", -# type_hint=Optional[torch.Tensor], -# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." -# ), -# InputParam( -# "pooled_prompt_embeds", -# required=True, -# type_hint=torch.Tensor, -# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "negative_pooled_prompt_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." -# ), -# InputParam( -# "timestep_cond", -# type_hint=Optional[torch.Tensor], -# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" -# ), -# InputParam( -# "mask", -# type_hint=Optional[torch.Tensor], -# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "masked_image_latents", -# type_hint=Optional[torch.Tensor], -# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "noise", -# type_hint=Optional[torch.Tensor], -# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." -# ), -# InputParam( -# "image_latents", -# type_hint=Optional[torch.Tensor], -# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." -# ), -# InputParam( -# "crops_coords", -# type_hint=Optional[Tuple[int]], -# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." -# ), -# InputParam( -# "ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
-# ), -# InputParam( -# "negative_ip_adapter_embeds", -# type_hint=Optional[torch.Tensor], -# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." -# ), -# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") -# ] - -# @property -# def intermediates_outputs(self) -> List[OutputParam]: -# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - -# @staticmethod -# def check_inputs(components, block_state): - -# num_channels_unet = components.unet.config.in_channels -# if num_channels_unet == 9: -# # default case for runwayml/stable-diffusion-inpainting -# if block_state.mask is None or block_state.masked_image_latents is None: -# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") -# num_channels_latents = block_state.latents.shape[1] -# num_channels_mask = block_state.mask.shape[1] -# num_channels_masked_image = block_state.masked_image_latents.shape[1] -# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: -# raise ValueError( -# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" -# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" -# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" -# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" -# " `components.unet` or your `mask_image` or `image` input." -# ) -# @staticmethod -# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): - -# accepted_kwargs = set(inspect.signature(func).parameters.keys()) -# extra_kwargs = {} -# for key, value in kwargs.items(): -# if key in accepted_kwargs and key not in exclude_kwargs: -# extra_kwargs[key] = value - -# return extra_kwargs - - -# @torch.no_grad() -# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: - -# block_state = self.get_block_state(state) -# self.check_inputs(components, block_state) -# block_state.device = components._execution_device -# print(f" block_state: {block_state}") - -# controlnet = unwrap_module(components.controlnet) - -# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline -# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) -# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) - -# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - -# # (1) setup guider -# # disable for LCMs -# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False -# if block_state.disable_guidance: -# components.guider.disable() -# else: -# components.guider.enable() -# components.guider.set_input_fields( -# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), -# add_time_ids=("add_time_ids", "negative_add_time_ids"), -# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), -# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), -# ) - -# # (5) Denoise loop -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# # prepare latent input for unet -# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) -# # adjust latent input for inpainting -# block_state.num_channels_unet = components.unet.config.in_channels -# if block_state.num_channels_unet == 9: -# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - -# # cond_scale (controlnet input) -# if isinstance(block_state.controlnet_keep[i], list): -# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] -# else: -# block_state.controlnet_cond_scale = block_state.conditioning_scale -# if isinstance(block_state.controlnet_cond_scale, list): -# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] -# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - -# # default controlnet output/unet input for guess mode + conditional path -# block_state.down_block_res_samples_zeros = None -# block_state.mid_block_res_sample_zeros = None - -# # guided denoiser step -# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) -# guider_state = components.guider.prepare_inputs(block_state) - -# for guider_state_batch in guider_state: -# components.guider.prepare_models(components.unet) - -# # Prepare additional conditionings -# guider_state_batch.added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } -# if guider_state_batch.ip_adapter_embeds is not None: -# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds - -# # Prepare controlnet additional conditionings -# guider_state_batch.controlnet_added_cond_kwargs = { -# "text_embeds": guider_state_batch.pooled_prompt_embeds, -# "time_ids": guider_state_batch.add_time_ids, -# } - -# if block_state.guess_mode and not components.guider.is_conditional: -# # guider always run uncond batch first, so these tensors should be set already -# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros -# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros -# else: -# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# controlnet_cond=block_state.controlnet_cond, -# conditioning_scale=block_state.conditioning_scale, -# guess_mode=block_state.guess_mode, -# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, -# return_dict=False, -# **block_state.extra_controlnet_kwargs, -# ) - -# if block_state.down_block_res_samples_zeros is None: -# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] -# if block_state.mid_block_res_sample_zeros is None: -# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - - - -# guider_state_batch.noise_pred = components.unet( -# block_state.scaled_latents, -# t, -# encoder_hidden_states=guider_state_batch.prompt_embeds, -# timestep_cond=block_state.timestep_cond, -# cross_attention_kwargs=block_state.cross_attention_kwargs, -# added_cond_kwargs=guider_state_batch.added_cond_kwargs, -# down_block_additional_residuals=guider_state_batch.down_block_res_samples, -# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, -# return_dict=False, -# )[0] -# components.guider.cleanup_models(components.unet) - -# # Perform guidance -# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) - -# # Perform scheduler step using the predicted output -# block_state.latents_dtype = block_state.latents.dtype -# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - -# if block_state.latents.dtype != block_state.latents_dtype: -# if torch.backends.mps.is_available(): -# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 -# block_state.latents = block_state.latents.to(block_state.latents_dtype) - -# # adjust latent for inpainting -# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: -# block_state.init_latents_proper = block_state.image_latents -# if i < len(block_state.timesteps) - 1: -# block_state.noise_timestep = block_state.timesteps[i + 1] -# block_state.init_latents_proper = components.scheduler.add_noise( -# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) -# ) - -# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents - -# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): -# progress_bar.update() - -# self.add_block_state(state, block_state) - -# return components, state \ No newline at end of file
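
A small, self-contained sketch of the composition pattern removed above: loop variants are assembled from `block_classes`/`block_names`, and the `AutoPipelineBlocks` subclasses pick a variant at runtime through `block_trigger_inputs` (the first entry whose trigger input is present wins, with `None` acting as the fallback). The names below (`SimpleAutoBlocks`, `select`) are hypothetical stand-ins used only to illustrate that dispatch logic; they are not part of the diffusers API.

from typing import Any, Dict, List, Optional


class SimpleAutoBlocks:
    """Toy mock of trigger-input dispatch: return the first block whose trigger
    input is present in the state; a ``None`` trigger always matches."""

    def __init__(self, blocks: Dict[str, Any], trigger_inputs: List[Optional[str]]):
        # `blocks` maps block_name -> block object, in priority order;
        # `trigger_inputs` is aligned with that order (like block_names / block_trigger_inputs).
        self.blocks = list(blocks.items())
        self.trigger_inputs = trigger_inputs

    def select(self, state: Dict[str, Any]) -> str:
        for (name, _block), trigger in zip(self.blocks, self.trigger_inputs):
            if trigger is None or state.get(trigger) is not None:
                return name
        raise ValueError("no sub-block matched the provided state")


auto_denoise = SimpleAutoBlocks(
    blocks={"inpaint_denoise": object(), "denoise": object()},
    trigger_inputs=["mask", None],
)
print(auto_denoise.select({"mask": "mask-tensor"}))  # -> "inpaint_denoise"
print(auto_denoise.select({"prompt": "a cat"}))      # -> "denoise"

This mirrors how StableDiffusionXLAutoDenoiseStep falls back from the ControlNet variants to the plain denoise loop when `controlnet_cond` is absent, and how the inpainting variants are only selected when `mask` is present.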