1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

refactor the denoiseestep using LoopSequential! also add a new file for denoise step

This commit is contained in:
yiyixuxu
2025-05-08 11:28:52 +02:00
parent d89631fc50
commit 0f0618ff2b
3 changed files with 1568 additions and 556 deletions

View File

@@ -184,6 +184,23 @@ class BlockState:
for key, value in kwargs.items():
setattr(self, key, value)
def __getitem__(self, key: str):
# allows block_state["foo"]
return getattr(self, key, None)
def __setitem__(self, key: str, value: Any):
# allows block_state["foo"] = "bar"
setattr(self, key, value)
def as_dict(self):
"""
Convert BlockState to a dictionary.
Returns:
Dict[str, Any]: Dictionary containing all attributes of the BlockState
"""
return {key: value for key, value in self.__dict__.items()}
def __repr__(self):
def format_value(v):
# Handle tensors directly
@@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
for block_name, inputs in named_input_lists:
for input_param in inputs:
if input_param.name in combined_dict:
current_param = combined_dict[input_param.name]
if input_param.name is None and input_param.kwargs_type is not None:
input_name = "*_" + input_param.kwargs_type
else:
input_name = input_param.name
if input_name in combined_dict:
current_param = combined_dict[input_name]
if (current_param.default is not None and
input_param.default is not None and
current_param.default != input_param.default):
@@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
for block_name, outputs in named_output_lists:
for output_param in outputs:
if output_param.name not in combined_dict:
if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
@@ -919,6 +940,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
return self.get_inputs()
def get_inputs(self):
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
combined_inputs = combine_inputs(*named_inputs)
# mark Required inputs only if that input is required any of the blocks
@@ -931,6 +955,9 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
@property
def intermediates_inputs(self) -> List[str]:
return self.get_intermediates_inputs()
def get_intermediates_inputs(self):
inputs = []
outputs = set()
@@ -1169,7 +1196,262 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
expected_configs=self.expected_configs
)
#YiYi TODO: __repr__
class LoopSequentialPipelineBlocks(ModularPipelineMixin):
"""
A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
"""
model_name = None
block_classes = []
block_names = []
@property
def description(self) -> str:
"""Description of the block. Must be implemented by subclasses."""
raise NotImplementedError("description method must be implemented in subclasses")
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return []
@property
def loop_expected_configs(self) -> List[ConfigSpec]:
return []
@property
def loop_inputs(self) -> List[InputParam]:
"""List of input parameters. Must be implemented by subclasses."""
return []
@property
def loop_intermediates_inputs(self) -> List[InputParam]:
"""List of intermediate input parameters. Must be implemented by subclasses."""
return []
@property
def loop_intermediates_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
@property
def loop_required_inputs(self) -> List[str]:
input_names = []
for input_param in self.loop_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@property
def loop_required_intermediates_inputs(self) -> List[str]:
input_names = []
for input_param in self.loop_intermediates_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
# modified from SequentialPipelineBlocks to include loop_expected_components
@property
def expected_components(self):
expected_components = []
for block in self.blocks.values():
for component in block.expected_components:
if component not in expected_components:
expected_components.append(component)
for component in self.loop_expected_components:
if component not in expected_components:
expected_components.append(component)
return expected_components
# modified from SequentialPipelineBlocks to include loop_expected_configs
@property
def expected_configs(self):
expected_configs = []
for block in self.blocks.values():
for config in block.expected_configs:
if config not in expected_configs:
expected_configs.append(config)
for config in self.loop_expected_configs:
if config not in expected_configs:
expected_configs.append(config)
return expected_configs
# modified from SequentialPipelineBlocks to include loop_inputs
def get_inputs(self):
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
named_inputs.append(("loop", self.loop_inputs))
combined_inputs = combine_inputs(*named_inputs)
# mark Required inputs only if that input is required any of the blocks
for input_param in combined_inputs:
if input_param.name in self.required_inputs:
input_param.required = True
else:
input_param.required = False
return combined_inputs
# Copied from SequentialPipelineBlocks
@property
def inputs(self):
return self.get_inputs()
# modified from SequentialPipelineBlocks to include loop_intermediates_inputs
@property
def intermediates_inputs(self):
intermediates = self.get_intermediates_inputs()
intermediate_names = [input.name for input in intermediates]
for loop_intermediate_input in self.loop_intermediates_inputs:
if loop_intermediate_input.name not in intermediate_names:
intermediates.append(loop_intermediate_input)
return intermediates
# Copied from SequentialPipelineBlocks
def get_intermediates_inputs(self):
inputs = []
outputs = set()
# Go through all blocks in order
for block in self.blocks.values():
# Add inputs that aren't in outputs yet
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
should_add_outputs = False
if should_add_outputs:
# Add this block's outputs
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
outputs.update(block_intermediates_outputs)
return inputs
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
@property
def required_inputs(self) -> List[str]:
# Get the first block from the dictionary
first_block = next(iter(self.blocks.values()))
required_by_any = set(getattr(first_block, "required_inputs", set()))
required_by_loop = set(getattr(self, "loop_required_inputs", set()))
required_by_any.update(required_by_loop)
# Union with required inputs from all other blocks
for block in list(self.blocks.values())[1:]:
block_required = set(getattr(block, "required_inputs", set()))
required_by_any.update(block_required)
return list(required_by_any)
# modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
@property
def required_intermediates_inputs(self) -> List[str]:
required_intermediates_inputs = []
for input_param in self.intermediates_inputs:
if input_param.required:
required_intermediates_inputs.append(input_param.name)
for input_param in self.loop_intermediates_inputs:
if input_param.required:
required_intermediates_inputs.append(input_param.name)
return required_intermediates_inputs
# YiYi TODO: this need to be thought about more
# modified from SequentialPipelineBlocks to include loop_intermediates_outputs
@property
def intermediates_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
for output in self.loop_intermediates_outputs:
if output.name not in set([output.name for output in combined_outputs]):
combined_outputs.append(output)
return combined_outputs
# YiYi TODO: this need to be thought about more
# copied from SequentialPipelineBlocks
@property
def outputs(self) -> List[str]:
return next(reversed(self.blocks.values())).intermediates_outputs
def __init__(self):
blocks = OrderedDict()
for block_name, block_cls in zip(self.block_names, self.block_classes):
blocks[block_name] = block_cls()
self.blocks = blocks
def loop_step(self, components, state: PipelineState, **kwargs):
for block_name, block in self.blocks.items():
try:
components, state = block(components, state, **kwargs)
except Exception as e:
error_msg = (
f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
f"Error details: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
)
logger.error(error_msg)
raise
return components, state
def __call__(self, components, state: PipelineState) -> PipelineState:
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}
# Check inputs
for input_param in self.inputs:
if input_param.name:
value = state.get_input(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v
# Check intermediates
for input_param in self.intermediates_inputs:
if input_param.name:
value = state.get_intermediate(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all intermediates with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
if intermediates_kwargs:
for k, v in intermediates_kwargs.items():
if v is not None:
if k not in data:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)
def add_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediates_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
# YiYi TODO:
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)

View File

@@ -0,0 +1,729 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from tqdm.auto import tqdm
from ...configuration_utils import FrozenDict
from ...models import ControlNetModel, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import unwrap_module
from ..modular_pipeline import (
PipelineBlock,
PipelineState,
LoopSequentialPipelineBlocks,
InputParam,
OutputParam,
BlockState,
ComponentSpec,
)
from ...guiders import ClassifierFreeGuidance
from .pipeline_stable_diffusion_xl_modular import StableDiffusionXLModularLoader
from dataclasses import asdict
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLDenoiseLoopLatentsStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return "step within the denoising loop that prepare the latent input for the denoiser"
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
return components, block_state
# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLDenoiseLoopInpaintLatentsStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return "step within the denoising loop that prepare the latent input for the denoiser"
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
"masked_image_latents",
type_hint=Optional[torch.Tensor],
description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("scaled_latents", type_hint=torch.Tensor, description="The scaled latents input for denoiser")]
@staticmethod
def check_inputs(components, block_state):
num_channels_unet = components.num_channels_unet
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
if block_state.mask is None or block_state.masked_image_latents is None:
raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
num_channels_latents = block_state.latents.shape[1]
num_channels_mask = block_state.mask.shape[1]
num_channels_masked_image = block_state.masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
raise ValueError(
f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `components.unet` or your `mask_image` or `image` input."
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, loop_idx: int, t: int):
self.check_inputs(components, block_state)
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
if components.num_channels_unet == 9:
block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
return components, block_state
# loop step (2): denoise the latents with guidance
class StableDiffusionXLDenoiseLoopDenoiserStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"scaled_latents",
required=True,
type_hint=torch.Tensor,
description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop."
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step."
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
)
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields ={
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=cond_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLDenoiseLoopControlNetDenoiserStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
]
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
required=True,
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"guess_mode",
required=True,
type_hint=bool,
description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"scaled_latents",
required=True,
type_hint=torch.Tensor,
description="The prepared latents input to use for the denoiser. Can be generated in latent step within the denoise loop."
),
InputParam(
"timestep_cond",
type_hint=Optional[torch.Tensor],
description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step"
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
"please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
)
),
InputParam(
kwargs_type="controlnet_kwargs",
description=(
"additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )"
"please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
)
)
]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs)
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields ={
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
"time_ids": ("add_time_ids", "negative_add_time_ids"),
"text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
"image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
}
# cond_scale for the timestep (controlnet input)
if isinstance(block_state.controlnet_keep[i], list):
block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])]
else:
controlnet_cond_scale = block_state.conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
# default controlnet output/unet input for guess mode + conditional path
block_state.down_block_res_samples_zeros = None
block_state.mid_block_res_sample_zeros = None
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
# Prepare additional conditionings
added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
# Prepare controlnet additional conditionings
controlnet_added_cond_kwargs = {
"text_embeds": guider_state_batch.text_embeds,
"time_ids": guider_state_batch.time_ids,
}
# run controlnet for the guidance batch
if block_state.guess_mode and not components.guider.is_conditional:
# guider always run uncond batch first, so these tensors should be set already
down_block_res_samples = block_state.down_block_res_samples_zeros
mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
down_block_res_samples, mid_block_res_sample = components.controlnet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.controlnet_cond,
conditioning_scale=block_state.cond_scale,
guess_mode=block_state.guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
**extra_controlnet_kwargs,
)
# assign it to block_state so it will be available for the uncond guidance batch
if block_state.down_block_res_samples_zeros is None:
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
if block_state.mid_block_res_sample_zeros is None:
block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
# Predict the noise
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
# loop step (3): scheduler step to update latents
class StableDiffusionXLDenoiseLoopUpdateLatentsStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("generator"),
InputParam("eta", default=0.0),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
#YiYi TODO: move this out of here
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
return components, block_state
class StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("generator"),
InputParam("eta", default=0.0),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"mask",
type_hint=Optional[torch.Tensor],
description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
),
InputParam(
"noise",
type_hint=Optional[torch.Tensor],
description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
),
InputParam(
"image_latents",
type_hint=Optional[torch.Tensor],
description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@staticmethod
def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
accepted_kwargs = set(inspect.signature(func).parameters.keys())
extra_kwargs = {}
for key, value in kwargs.items():
if key in accepted_kwargs and key not in exclude_kwargs:
extra_kwargs[key] = value
return extra_kwargs
def check_inputs(self, components, block_state):
if components.num_channels_unet == 4:
if block_state.image_latents is None:
raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
if block_state.mask is None:
raise ValueError(f"mask is required for this step {self.__class__.__name__}")
if block_state.noise is None:
raise ValueError(f"noise is required for this step {self.__class__.__name__}")
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
if block_state.latents.dtype != block_state.latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
# adjust latent for inpainting
if components.num_channels_unet == 4:
block_state.init_latents_proper = block_state.image_latents
if i < len(block_state.timesteps) - 1:
block_state.noise_timestep = block_state.timesteps[i + 1]
block_state.init_latents_proper = components.scheduler.add_noise(
block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
)
block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents
return components, block_state
# the loop wrapper that iterates over the timesteps
class StableDiffusionXLDenoiseLoop(LoopSequentialPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
]
@property
def loop_intermediates_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
]
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if block_state.disable_guidance:
components.guider.disable()
else:
components.guider.enable()
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0):
progress_bar.update()
self.add_block_state(state, block_state)
return components, state
# StableDiffusionXLControlNetDenoiseStep
class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoop):
block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep]
block_names = ["prepare_latents", "denoiser", "update_latents"]
class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoop):
block_classes = [StableDiffusionXLDenoiseLoopLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopUpdateLatentsStep]
block_names = ["prepare_latents", "denoiser", "update_latents"]
class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoop):
block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep]
block_names = ["prepare_latents", "denoiser", "update_latents"]
class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop):
block_classes = [StableDiffusionXLDenoiseLoopInpaintLatentsStep, StableDiffusionXLDenoiseLoopControlNetDenoiserStep, StableDiffusionXLDenoiseLoopInpaintUpdateLatentsStep]
block_names = ["prepare_latents", "denoiser", "update_latents"]