From efd70b783871aa7b3e02bd8252afbc8e45eeb314 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 3 May 2025 20:22:05 +0200 Subject: [PATCH] seperate controlnet step into input + denoise --- .../pipeline_stable_diffusion_xl_modular.py | 466 +++++++++++------- 1 file changed, 299 insertions(+), 167 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 81808540ee..ea77428343 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2395,27 +2395,20 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): return components, state -class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): +class StableDiffusionXLControlNetInputStep(PipelineBlock): model_name = "stable-diffusion-xl" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config"), - ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property def description(self) -> str: - return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + return "step that prepare inputs for controlnet" @property def inputs(self) -> List[Tuple[str, Any]]: @@ -2426,9 +2419,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("cross_attention_kwargs"), - InputParam("generator"), - InputParam("eta", default=0.0), ] @property @@ -2452,12 +2442,246 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("control_image", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("controlnet_conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. 
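# [Editor's note: illustrative sketch, not part of the patch.]
# `prepare_control_image` (defined just below) preprocesses the control image and then
# repeats it along the batch dimension so it matches the final model batch size
# (batch_size * num_images_per_prompt). A simplified, standalone version of that repeat
# logic, assuming a single preprocessed control image and a final batch size of 4:

import torch

image = torch.randn(1, 3, 1024, 1024)  # one preprocessed control image
batch_size = 4                          # prompts * num_images_per_prompt
repeat_by = batch_size if image.shape[0] == 1 else 1  # a single image is tiled to the full batch
image = image.repeat_interleave(repeat_by, dim=0)
assert image.shape == (4, 3, 1024, 1024)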
add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = 
self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + + + self.add_block_state(state, block_state) + + return components, state + +class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ InputParam( "num_inference_steps", required=True, type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), + InputParam("num_images_per_prompt", default=1), + InputParam("cross_attention_kwargs"), + InputParam("generator"), + InputParam("eta", default=0.0), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "control_image", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_start", + required=True, + type_hint=float, + description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "control_guidance_end", + required=True, + type_hint=float, + description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_conditioning_scale", + required=True, + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. 
Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), InputParam( "prompt_embeds", required=True, @@ -2557,36 +2781,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): " `components.unet` or your `mask_image` or `image` input." ) - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - @staticmethod - def prepare_control_image( - components, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - - image_batch_size = image.shape[0] - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @staticmethod @@ -2613,103 +2807,19 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): block_state = self.get_block_state(state) self.check_inputs(components, block_state) - - block_state.num_channels_unet = components.unet.config.in_channels - - # (1) prepare controlnet inputs block_state.device = components._execution_device - block_state.height, block_state.width = block_state.latents.shape[-2:] - block_state.height = block_state.height * components.vae_scale_factor - block_state.width = block_state.width * components.vae_scale_factor - controlnet = unwrap_module(components.controlnet) - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): - block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] - elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): - block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] - elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - block_state.control_guidance_start, block_state.control_guidance_end = ( - mult * [block_state.control_guidance_start], - mult * [block_state.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): - 
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( - components, - image=block_state.control_image, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in block_state.control_image: - control_image = self.prepare_control_image( - components, - image=control_image_, - width=block_state.width, - height=block_state.height, - batch_size=block_state.batch_size * block_state.num_images_per_prompt, - num_images_per_prompt=block_state.num_images_per_prompt, - device=block_state.device, - dtype=controlnet.dtype, - crops_coords=block_state.crops_coords, - ) - - control_images.append(control_image) - - block_state.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - block_state.controlnet_keep = [] - for i in range(len(block_state.timesteps)): - keeps = [ - 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) - for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) - ] - block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # (2) Prepare conditional inputs for unet using the guider + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + # (1) setup guider + # disable for LCMs block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False if block_state.disable_guidance: components.guider.disable() else: components.guider.enable() - - # (4) Prepare extra step kwargs. 
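# [Editor's note: illustrative sketch, not part of the patch.]
# The `controlnet_keep` schedule (now built in `StableDiffusionXLControlNetInputStep`
# above) switches the ControlNet on only for the fraction of steps between
# `control_guidance_start` and `control_guidance_end`. A small numeric check of the
# per-step keep formula, assuming 10 timesteps and guidance over the first half:

num_timesteps = 10
control_guidance_start, control_guidance_end = 0.0, 0.5

controlnet_keep = [
    1.0 - float(i / num_timesteps < control_guidance_start or (i + 1) / num_timesteps > control_guidance_end)
    for i in range(num_timesteps)
]
assert controlnet_keep == [1.0] * 5 + [0.0] * 5  # ControlNet active for the first 5 steps only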
TODO: Logic should ideally just be moved out of the pipeline - block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - components.guider.set_input_fields( prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), add_time_ids=("add_time_ids", "negative_add_time_ids"), @@ -2720,11 +2830,16 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): # (5) Denoise loop with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_data = components.guider.prepare_inputs(block_state) + # prepare latent input for unet block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + # adjust latent input for inpainting + block_state.num_channels_unet = components.unet.config.in_channels + if block_state.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + # cond_scale (controlnet input) if isinstance(block_state.controlnet_keep[i], list): block_state.cond_scale = [c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])] else: @@ -2733,62 +2848,69 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] - for batch in guider_data: + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(block_state) + + for guider_state_batch in guider_state: components.guider.prepare_models(components.unet) # Prepare additional conditionings - batch.added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, + guider_state_batch.added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - if batch.ip_adapter_embeds is not None: - batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + if guider_state_batch.ip_adapter_embeds is not None: + guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds # Prepare controlnet additional conditionings - batch.controlnet_added_cond_kwargs = { - "text_embeds": batch.pooled_prompt_embeds, - "time_ids": batch.add_time_ids, + guider_state_batch.controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.pooled_prompt_embeds, + "time_ids": guider_state_batch.add_time_ids, } - # Will always be run atleast once with every guider - if components.guider.is_conditional or not block_state.guess_mode: - block_state.down_block_res_samples, block_state.mid_block_res_sample = components.controlnet( + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros 
+ guider_state_batch.mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, controlnet_cond=block_state.control_image, conditioning_scale=block_state.cond_scale, guess_mode=block_state.guess_mode, - added_cond_kwargs=batch.controlnet_added_cond_kwargs, + added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, return_dict=False, ) - batch.down_block_res_samples = block_state.down_block_res_samples - batch.mid_block_res_sample = block_state.mid_block_res_sample + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) - if components.guider.is_unconditional and block_state.guess_mode: - batch.down_block_res_samples = [torch.zeros_like(d) for d in block_state.down_block_res_samples] - batch.mid_block_res_sample = torch.zeros_like(block_state.mid_block_res_sample) - # Prepare for inpainting - if block_state.num_channels_unet == 9: - block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) - - batch.noise_pred = components.unet( + + guider_state_batch.noise_pred = components.unet( block_state.scaled_latents, t, - encoder_hidden_states=batch.prompt_embeds, + encoder_hidden_states=guider_state_batch.prompt_embeds, timestep_cond=block_state.timestep_cond, cross_attention_kwargs=block_state.cross_attention_kwargs, - added_cond_kwargs=batch.added_cond_kwargs, - down_block_additional_residuals=batch.down_block_res_samples, - mid_block_additional_residual=batch.mid_block_res_sample, + added_cond_kwargs=guider_state_batch.added_cond_kwargs, + down_block_additional_residuals=guider_state_batch.down_block_res_samples, + mid_block_additional_residual=guider_state_batch.mid_block_res_sample, return_dict=False, )[0] components.guider.cleanup_models(components.unet) # Perform guidance - block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) # Perform scheduler step using the predicted output block_state.latents_dtype = block_state.latents.dtype @@ -2799,6 +2921,7 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): # some platforms (eg. 
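# [Editor's note: illustrative sketch, not part of the patch.]
# In guess mode the ControlNet forward pass above only runs for the conditional batch;
# the unconditional batch reuses cached zero tensors shaped like the conditional
# residuals, so the ControlNet adds nothing to that UNet pass. A standalone illustration
# of the zero-residual construction, with arbitrary example shapes:

import torch

down_block_res_samples = [torch.randn(1, 320, 128, 128), torch.randn(1, 640, 64, 64)]
mid_block_res_sample = torch.randn(1, 1280, 32, 32)

down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)

assert all(z.shape == d.shape for z, d in zip(down_block_res_samples_zeros, down_block_res_samples))
assert not mid_block_res_sample_zeros.any()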
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                     block_state.latents = block_state.latents.to(block_state.latents_dtype)
 
+                # adjust latent for inpainting
                 if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None:
                     block_state.init_latents_proper = block_state.image_latents
                     if i < len(block_state.timesteps) - 1:
@@ -3463,6 +3586,16 @@ class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
             " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
             " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
 
+class StableDiffusionXLControlNetStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLControlNetInputStep, StableDiffusionXLControlNetDenoiseStep]
+    block_names = ["prepare_input", "denoise"]
+
+    @property
+    def description(self):
+        return "ControlNet step that denoises the latents.\n" + \
+            "This is a sequential pipeline block:\n" + \
+            " - `StableDiffusionXLControlNetInputStep` is used to prepare the inputs for the denoise step.\n" + \
+            " - `StableDiffusionXLControlNetDenoiseStep` is used to denoise the latents."
 
 class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
     block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
@@ -3477,10 +3610,9 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
         " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
         " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided."
 
-
 # Denoise
 class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
+    block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
     block_names = ["controlnet_union", "controlnet", "unet"]
     block_trigger_inputs = ["control_mode", "control_image", None]
 
@@ -3489,7 +3621,7 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
         return "Denoise step that denoise the latents.\n" + \
             "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
             " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
-            " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \
+            " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \
             " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided."
 
 # After denoise
@@ -3597,7 +3729,7 @@ INPAINT_BLOCKS = OrderedDict([
 ])
 
 CONTROLNET_BLOCKS = OrderedDict([
-    ("denoise", StableDiffusionXLControlNetDenoiseStep),
+    ("denoise", StableDiffusionXLControlNetStep),
 ])
 
 CONTROLNET_UNION_BLOCKS = OrderedDict([
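# [Editor's note: illustrative sketch, not part of the patch.]
# `StableDiffusionXLControlNetStep` above simply chains the two new blocks: the input
# step writes intermediates such as `control_image` and `controlnet_keep` into the
# pipeline state, and the denoise step declares those same names as required
# intermediates inputs. A toy, framework-free illustration of that contract, using
# hypothetical stand-in functions rather than the real blocks:

from collections import OrderedDict

def prepare_input(state):
    # stands in for StableDiffusionXLControlNetInputStep
    state["controlnet_keep"] = [1.0, 1.0, 0.0]
    return state

def denoise(state):
    # stands in for StableDiffusionXLControlNetDenoiseStep
    assert "controlnet_keep" in state, "required intermediate produced by the previous block"
    return state

blocks = OrderedDict([("prepare_input", prepare_input), ("denoise", denoise)])
state = {}
for _, block in blocks.items():
    state = block(state)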