
separate controlnet step into input + denoise

This commit is contained in:
yiyixuxu
2025-05-03 20:22:05 +02:00
parent 7ca860c24b
commit efd70b7838
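In short, this change splits the single ControlNet denoise block into an input-preparation block and a denoise block that are composed sequentially. A minimal sketch of the resulting wiring, using the class names introduced in the diff below:

    class StableDiffusionXLControlNetStep(SequentialPipelineBlocks):
        block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep]
        block_names = ["prepare_input", "denoise"]

The input block writes control_image, control_guidance_start/end, controlnet_conditioning_scale, guess_mode and controlnet_keep into the pipeline state; the denoise block declares them as required intermediate inputs and reads them back inside the denoising loop.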


@@ -2395,27 +2395,20 @@ class StableDiffusionXLDenoiseStep(PipelineBlock):
return components, state
class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
class StableDiffusionXLControlNetInputStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"),
]
@property
def description(self) -> str:
return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process"
return "step that prepares the inputs for the controlnet denoise step"
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -2426,9 +2419,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
InputParam("eta", default=0.0),
]
@property
@@ -2452,12 +2442,246 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"crops_coords",
type_hint=Optional[Tuple[int]],
description="The crop coordinates used to preprocess/postprocess the image and mask, for the inpainting task only. Can be generated in vae_encode step."
),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"),
OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"),
OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"),
OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"),
OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
]
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
# 1. return image without applying any guidance
# 2. add crops_coords and resize_mode to preprocess()
@staticmethod
def prepare_control_image(
components,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
crops_coords=None,
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
return image
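# Illustrative example (assumed values): with 2 prompts and num_images_per_prompt=2, the caller
# passes batch_size=4 to this helper. A single control image (image_batch_size == 1) is then
# repeated 4x to cover the full batch; one control image per prompt (image_batch_size == 2) is
# repeated num_images_per_prompt=2x, which also yields 4 images.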
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
controlnet = unwrap_module(components.controlnet)
# (1.1)
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
block_state.control_guidance_start, block_state.control_guidance_end = (
mult * [block_state.control_guidance_start],
mult * [block_state.control_guidance_end],
)
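# Illustrative example (assumed values): with a 2-net MultiControlNetModel and scalar inputs
# control_guidance_start=0.0 and control_guidance_end=0.5, the normalization above yields
# control_guidance_start=[0.0, 0.0] and control_guidance_end=[0.5, 0.5].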
# (1.2)
# controlnet_conditioning_scale (align format)
if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float):
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
# (1.3)
# global_pool_conditions
block_state.global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
# (1.4)
# guess_mode
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
# (1.5)
# control_image
if isinstance(controlnet, ControlNetModel):
block_state.control_image = self.prepare_control_image(
components,
image=block_state.control_image,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
crops_coords=block_state.crops_coords,
)
elif isinstance(controlnet, MultiControlNetModel):
control_images = []
for control_image_ in block_state.control_image:
control_image = self.prepare_control_image(
components,
image=control_image_,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
crops_coords=block_state.crops_coords,
)
control_images.append(control_image)
block_state.control_image = control_images
else:
assert False
# (1.6)
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
keeps = [
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
]
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
self.add_block_state(state, block_state)
return components, state
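# Illustrative example (assumed values): with 4 timesteps, control_guidance_start=0.0 and
# control_guidance_end=0.5, the controlnet_keep loop above yields [1.0, 1.0, 0.0, 0.0],
# i.e. ControlNet conditioning is applied only during the first half of the schedule.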
class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
]
@property
def description(self) -> str:
return "step that iteratively denoises the latents for the text-to-image/image-to-image/inpainting generation process, using ControlNet to condition the denoising process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam("num_images_per_prompt", default=1),
InputParam("cross_attention_kwargs"),
InputParam("generator"),
InputParam("eta", default=0.0),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam(
"control_image",
required=True,
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"control_guidance_start",
required=True,
type_hint=float,
description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"control_guidance_end",
required=True,
type_hint=float,
description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"controlnet_conditioning_scale",
required=True,
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"guess_mode",
required=True,
type_hint=bool,
description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step."
),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam(
"prompt_embeds",
required=True,
@@ -2557,36 +2781,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
" `components.unet` or your `mask_image` or `image` input."
)
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
# 1. return image without apply any guidance
# 2. add crops_coords and resize_mode to preprocess()
@staticmethod
def prepare_control_image(
components,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
crops_coords=None,
):
if crops_coords is not None:
image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
else:
image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
@staticmethod
@@ -2613,103 +2807,19 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.num_channels_unet = components.unet.config.in_channels
# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
controlnet = unwrap_module(components.controlnet)
# (1.1)
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
block_state.control_guidance_start, block_state.control_guidance_end = (
mult * [block_state.control_guidance_start],
mult * [block_state.control_guidance_end],
)
# (1.2)
# controlnet_conditioning_scale (align format)
if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float):
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets)
# (1.3)
# global_pool_conditions
block_state.global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
# (1.4)
# guess_mode
block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
# (1.5)
# control_image
if isinstance(controlnet, ControlNetModel):
block_state.control_image = self.prepare_control_image(
components,
image=block_state.control_image,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
crops_coords=block_state.crops_coords,
)
elif isinstance(controlnet, MultiControlNetModel):
control_images = []
for control_image_ in block_state.control_image:
control_image = self.prepare_control_image(
components,
image=control_image_,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
crops_coords=block_state.crops_coords,
)
control_images.append(control_image)
block_state.control_image = control_images
else:
assert False
# (1.6)
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
keeps = [
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
]
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
# (2) Prepare conditional inputs for unet using the guider
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
# (1) setup guider
# disable for LCMs
block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if block_state.disable_guidance:
components.guider.disable()
else:
components.guider.enable()
# (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
components.guider.set_input_fields(
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
add_time_ids=("add_time_ids", "negative_add_time_ids"),
@@ -2720,11 +2830,16 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# (5) Denoise loop
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_data = components.guider.prepare_inputs(block_state)
# prepare latent input for unet
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
# adjust latent input for inpainting
block_state.num_channels_unet = components.unet.config.in_channels
if block_state.num_channels_unet == 9:
block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
# cond_scale (controlnet input)
if isinstance(block_state.controlnet_keep[i], list):
block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])]
else:
@@ -2733,62 +2848,69 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0]
block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i]
for batch in guider_data:
# default controlnet output/unet input for guess mode + unconditional path
block_state.down_block_res_samples_zeros = None
block_state.mid_block_res_sample_zeros = None
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(block_state)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
# Prepare additional conditionings
batch.added_cond_kwargs = {
"text_embeds": batch.pooled_prompt_embeds,
"time_ids": batch.add_time_ids,
guider_state_batch.added_cond_kwargs = {
"text_embeds": guider_state_batch.pooled_prompt_embeds,
"time_ids": guider_state_batch.add_time_ids,
}
if batch.ip_adapter_embeds is not None:
batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds
if guider_state_batch.ip_adapter_embeds is not None:
guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds
# Prepare controlnet additional conditionings
batch.controlnet_added_cond_kwargs = {
"text_embeds": batch.pooled_prompt_embeds,
"time_ids": batch.add_time_ids,
guider_state_batch.controlnet_added_cond_kwargs = {
"text_embeds": guider_state_batch.pooled_prompt_embeds,
"time_ids": guider_state_batch.add_time_ids,
}
# Will always be run atleast once with every guider
if components.guider.is_conditional or not block_state.guess_mode:
block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet(
if block_state.guess_mode and not components.guider.is_conditional:
# the guider always runs the cond batch first, so these zero tensors have already been populated
guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros
guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet(
block_state.scaled_latents,
t,
encoder_hidden_states=batch.prompt_embeds,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.control_image,
conditioning_scale=block_state.cond_scale,
guess_mode=block_state.guess_mode,
added_cond_kwargs=batch.controlnet_added_cond_kwargs,
added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs,
return_dict=False,
)
batch.down_block_res_samples = block_state.down_block_res_samples
batch.mid_block_res_sample = block_state.mid_block_res_sample
if block_state.down_block_res_samples_zeros is None:
block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples]
if block_state.mid_block_res_sample_zeros is None:
block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample)
if components.guider.is_unconditional and block_state.guess_mode:
batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples]
batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample)
# Prepare for inpainting
if block_state.num_channels_unet == 9:
block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1)
batch.noise_pred = components.unet(
guider_state_batch.noise_pred = components.unet(
block_state.scaled_latents,
t,
encoder_hidden_states=batch.prompt_embeds,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
cross_attention_kwargs=block_state.cross_attention_kwargs,
added_cond_kwargs=batch.added_cond_kwargs,
down_block_additional_residuals=batch.down_block_res_samples,
mid_block_additional_residual=batch.mid_block_res_sample,
added_cond_kwargs=guider_state_batch.added_cond_kwargs,
down_block_additional_residuals=guider_state_batch.down_block_res_samples,
mid_block_additional_residual=guider_state_batch.mid_block_res_sample,
return_dict=False,
)[0]
components.guider.cleanup_models(components.unet)
# Perform guidance
block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data)
block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state)
# Perform scheduler step using the predicted output
block_state.latents_dtype = block_state.latents.dtype
@@ -2799,6 +2921,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(block_state.latents_dtype)
# adjust latent for inpainting
if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None:
block_state.init_latents_proper = block_state.image_latents
if i < len(block_state.timesteps) - 1:
@@ -3463,6 +3586,16 @@ class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
" - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
" - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
class StableDiffusionXLControlNetStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep]
block_names = ["prepare_input", "denoise"]
@property
def description(self):
return "Controlnet step that denoises the latents.\n" + \
"This is a sequential pipeline block:\n" + \
" - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \
" - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents."
class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
@@ -3477,10 +3610,9 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
" - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
" - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided."
# Denoise
class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
block_names = ["controlnet_union", "controlnet", "unet"]
block_trigger_inputs = ["control_mode", "control_image", None]
@@ -3489,7 +3621,7 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
return "Denoise step that denoises the latents.\n" + \
"This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
" - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
" - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \
" - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \
" - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided."
# After denoise
@@ -3597,7 +3729,7 @@ INPAINT_BLOCKS = OrderedDict([
])
CONTROLNET_BLOCKS = OrderedDict([
("denoise", StableDiffusionXLControlNetDenoiseStep),
("denoise", StableDiffusionXLControlNetStep),
])
CONTROLNET_UNION_BLOCKS = OrderedDict([