diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
index ea77428343..5ebdd383cc 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
@@ -2613,12 +2613,6 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam(
-                "num_inference_steps",
-                required=True,
-                type_hint=int,
-                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
-            ),
             InputParam("num_images_per_prompt", default=1),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
@@ -2755,6 +2749,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
                 type_hint=Optional[torch.Tensor],
                 description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
             ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
         ]
 
     @property
@@ -2940,25 +2940,198 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock):
 
         return components, state
 
 
+class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("controlnet", ControlNetUnionModel),
+            ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares inputs for the ControlNetUnion model"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("control_image", required=True),
+            InputParam("control_mode", default=[0]),
+            InputParam("control_guidance_start", default=0.0),
+            InputParam("control_guidance_end", default=1.0),
+            InputParam("controlnet_conditioning_scale", default=1.0),
+            InputParam("guess_mode", default=False),
+            InputParam("num_images_per_prompt", default=1),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
+            ),
+            InputParam(
+                "dtype",
+                required=True,
+                type_hint=torch.dtype,
+                description="The dtype of model tensor inputs. Can be generated in input step."
+            ),
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "crops_coords",
+                type_hint=Optional[Tuple[int]],
+                description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step."
+            ),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("control_image", type_hint=List[torch.Tensor], description="The processed control images"),
+            OutputParam("control_mode", type_hint=List[int], description="The control mode indices"),
+            OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active"),
+            OutputParam("num_control_type", type_hint=int, description="The number of control types available"),
+            OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"),
+            OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"),
+            OutputParam("controlnet_conditioning_scale", type_hint=float, description="The controlnet conditioning scale value"),
+            OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
+            OutputParam("global_pool_conditions", type_hint=bool, description="Whether global pool conditions are used"),
+            OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+        ]
+
+    # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+    # 1. return image without applying any guidance
+    # 2. add crops_coords and resize_mode to preprocess()
+    @staticmethod
+    def prepare_control_image(
+        components,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        crops_coords=None,
+    ):
+        if crops_coords is not None:
+            image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
+        else:
+            image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
+        image_batch_size = image.shape[0]
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+        image = image.to(device=device, dtype=dtype)
+        return image
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+
+        block_state = self.get_block_state(state)
+
+        controlnet = unwrap_module(components.controlnet)
+
+        device = components._execution_device
+        dtype = block_state.dtype or components.controlnet.dtype
+
+        block_state.height, block_state.width = block_state.latents.shape[-2:]
+        block_state.height = block_state.height * components.vae_scale_factor
+        block_state.width = block_state.width * components.vae_scale_factor
+
+        # control_guidance_start/control_guidance_end (align format)
+        if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
+            block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
+        elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
+            block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
+
+        # guess_mode
+        block_state.global_pool_conditions = controlnet.config.global_pool_conditions
+        block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
+
+        if not isinstance(block_state.control_image, list):
+            block_state.control_image = [block_state.control_image]
+
+        if not isinstance(block_state.control_mode, list):
+            block_state.control_mode = [block_state.control_mode]
+
+        if len(block_state.control_image) != len(block_state.control_mode):
+            raise ValueError("Expected len(control_image) == len(control_mode)")
+
+        # control_type
+        block_state.num_control_type = controlnet.config.num_control_type
+        block_state.control_type = [0 for _ in range(block_state.num_control_type)]
+        for control_idx in block_state.control_mode:
+            block_state.control_type[control_idx] = 1
+        block_state.control_type = torch.Tensor(block_state.control_type)
+
+        block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
+        repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
+        block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
+
+        # prepare control_image
+        for idx, _ in enumerate(block_state.control_image):
+            block_state.control_image[idx] = self.prepare_control_image(
+                components,
+                image=block_state.control_image[idx],
+                width=block_state.width,
+                height=block_state.height,
+                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+                num_images_per_prompt=block_state.num_images_per_prompt,
+                device=device,
+                dtype=dtype,
+                crops_coords=block_state.crops_coords,
+            )
+            block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
+
+        # controlnet_keep
+        block_state.controlnet_keep = []
+        for i in range(len(block_state.timesteps)):
+            block_state.controlnet_keep.append(
+                1.0
+                - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end)
+            )
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
 class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
     model_name = "stable-diffusion-xl"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("unet", UNet2DConditionModel),
-            ComponentSpec("controlnet", ControlNetUnionModel),
-            ComponentSpec("scheduler", EulerDiscreteScheduler),
             ComponentSpec(
                 "guider",
                 ClassifierFreeGuidance,
                 config=FrozenDict({"guidance_scale": 7.5}),
                 default_creation_method="from_config"),
-            ComponentSpec(
-                "control_image_processor",
-                VaeImageProcessor,
-                config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
-                default_creation_method="from_config"),
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+            ComponentSpec("controlnet", ControlNetUnionModel),
         ]
 
     @property
@@ -2967,12 +3140,6 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
         return [
-            InputParam("control_image", required=True),
-            InputParam("control_guidance_start", default=0.0),
-            InputParam("control_guidance_end", default=1.0),
-            InputParam("control_mode", required=True),
-            InputParam("controlnet_conditioning_scale", default=1.0),
-            InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
             InputParam("cross_attention_kwargs"),
             InputParam("generator"),
@@ -2983,15 +3150,75 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
     def intermediates_inputs(self) -> List[str]:
         return [
             InputParam(
-                "latents",
+                "control_image",
+                required=True,
+                type_hint=List[torch.Tensor],
+                description="The control images to use for conditioning. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "control_mode",
+                required=True,
+                type_hint=List[int],
+                description="The control mode indices. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "control_type",
                 required=True,
                 type_hint=torch.Tensor,
+                description="The control type tensor that specifies which control type is active. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "num_control_type",
+                required=True,
+                type_hint=int,
+                description="The number of control types available. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "control_guidance_start",
+                required=True,
+                type_hint=float,
+                description="The control guidance start value. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "control_guidance_end",
+                required=True,
+                type_hint=float,
+                description="The control guidance end value. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "controlnet_conditioning_scale",
+                required=True,
+                type_hint=float,
+                description="The controlnet conditioning scale value. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "guess_mode",
+                required=True,
+                type_hint=bool,
+                description="Whether guess mode is used. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "global_pool_conditions",
+                required=True,
+                type_hint=bool,
+                description="Whether global pool conditions are used. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "controlnet_keep",
+                required=True,
+                type_hint=List[float],
+                description="The controlnet keep values. Can be generated in prepare controlnet inputs step."
+            ),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
             ),
             InputParam(
                 "batch_size",
-                required=True, 
-                type_hint=int, 
+                required=True,
+                type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
             ),
             InputParam(
@@ -3045,23 +3272,23 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step."
             ),
             InputParam(
-                "mask", 
-                type_hint=Optional[torch.Tensor], 
+                "mask",
+                type_hint=Optional[torch.Tensor],
                 description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
             ),
             InputParam(
-                "masked_image_latents", 
-                type_hint=Optional[torch.Tensor], 
+                "masked_image_latents",
+                type_hint=Optional[torch.Tensor],
                 description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
             ),
             InputParam(
-                "noise", 
-                type_hint=Optional[torch.Tensor], 
+                "noise",
+                type_hint=Optional[torch.Tensor],
                 description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
             ),
             InputParam(
-                "image_latents", 
-                type_hint=Optional[torch.Tensor], 
+                "image_latents",
+                type_hint=Optional[torch.Tensor],
                 description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
             ),
             InputParam(
@@ -3070,19 +3297,19 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
                 description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
             ),
             InputParam(
-                "ip_adapter_embeds", 
-                type_hint=Optional[torch.Tensor], 
+                "ip_adapter_embeds",
+                type_hint=Optional[torch.Tensor],
                 description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
             ),
             InputParam(
-                "negative_ip_adapter_embeds", 
-                type_hint=Optional[torch.Tensor], 
+                "negative_ip_adapter_embeds",
+                type_hint=Optional[torch.Tensor],
                 description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
             ),
         ]
 
     @property
-    def intermediates_outputs(self) -> List[str]:
+    def intermediates_outputs(self) -> List[OutputParam]:
         return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
 
     @staticmethod
@@ -3105,39 +3332,7 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
             " `components.unet` or your `mask_image` or `image` input."
             )
 
-
-    # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
-    # 1. return image without apply any guidance
-    # 2. add crops_coords and resize_mode to preprocess()
-    @staticmethod
-    def prepare_control_image(
-        components,
-        image,
-        width,
-        height,
-        batch_size,
-        num_images_per_prompt,
-        device,
-        dtype,
-        crops_coords=None,
-    ):
-        if crops_coords is not None:
-            image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
-        else:
-            image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
-        image_batch_size = image.shape[0]
-        if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
-            # image batch size is the same as prompt batch size
-            repeat_by = num_images_per_prompt
-
-        image = image.repeat_interleave(repeat_by, dim=0)
-
-        image = image.to(device=device, dtype=dtype)
-
-        return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components
     @staticmethod
@@ -3164,85 +3359,20 @@ class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
         self.check_inputs(components, block_state)
 
         block_state.num_channels_unet = components.unet.config.in_channels
-
-        # (1) prepare controlnet inputs
         block_state.device = components._execution_device
-        block_state.height, block_state.width = block_state.latents.shape[-2:]
-        block_state.height = block_state.height * components.vae_scale_factor
-        block_state.width = block_state.width * components.vae_scale_factor
-        controlnet = unwrap_module(components.controlnet)
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
+        block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
 
-        # (1.1)
-        # control guidance
-        if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list):
-            block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start]
-        elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list):
-            block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end]
-
-        # (1.2)
-        # global_pool_conditions & guess_mode
-        block_state.global_pool_conditions = controlnet.config.global_pool_conditions
-        block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
-
-        # (1.3)
-        # control_type
-        block_state.num_control_type = controlnet.config.num_control_type
-
-        # (1.4)
-        # control_type
-        if not isinstance(block_state.control_image, list):
-            block_state.control_image = [block_state.control_image]
-
-        if not isinstance(block_state.control_mode, list):
-            block_state.control_mode = [block_state.control_mode]
-
-        if len(block_state.control_image) != len(block_state.control_mode):
-            raise ValueError("Expected len(control_image) == len(control_type)")
-
-        block_state.control_type = [0 for _ in range(block_state.num_control_type)]
-        for control_idx in block_state.control_mode:
-            block_state.control_type[control_idx] = 1
-
-        block_state.control_type = torch.Tensor(block_state.control_type)
-
-        # (1.5)
-        # prepare control_image
-        for idx, _ in enumerate(block_state.control_image):
-            block_state.control_image[idx] = self.prepare_control_image(
-                components,
-                image=block_state.control_image[idx],
-                width=block_state.width,
-                height=block_state.height,
-                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
-                num_images_per_prompt=block_state.num_images_per_prompt,
-                device=block_state.device,
-                dtype=controlnet.dtype,
-                crops_coords=block_state.crops_coords,
-            )
-            block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
-
-        # (1.6)
-        # controlnet_keep
-        block_state.controlnet_keep = []
-        for i in range(len(block_state.timesteps)):
-            block_state.controlnet_keep.append(
-                1.0
-                - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end)
-            )
-
-        # (2) Prepare conditional inputs for unet using the guider
-        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
+        # Set up the guider
+        # (guidance is disabled for LCMs, i.e. when the unet has time_cond_proj_dim set)
         block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
         if block_state.disable_guidance:
             components.guider.disable()
         else:
             components.guider.enable()
 
-        block_state.control_type = block_state.control_type.reshape(1, -1).to(block_state.device, dtype=block_state.prompt_embeds.dtype)
-        repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
-        block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
-
-        # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta)
-        block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
 
@@ -3612,7 +3742,19 @@ class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
 
+class StableDiffusionXLControlNetUnionStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetUnionDenoiseStep]
+    block_names = ["prepare_input", "denoise"]
+
+    @property
+    def description(self):
+        return "ControlNetUnion step that denoises the latents.\n" + \
+               "This is a sequential pipeline of blocks:\n" + \
+               " - `StableDiffusionXLControlNetUnionInputStep` is used to prepare the inputs for the denoise step.\n" + \
+               " - `StableDiffusionXLControlNetUnionDenoiseStep` is used to denoise the latents using the ControlNetUnion model."
+
+
 # Denoise
 class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
+    block_classes = [StableDiffusionXLControlNetUnionStep, StableDiffusionXLControlNetStep, StableDiffusionXLDenoiseStep]
     block_names = ["controlnet_union", "controlnet", "unet"]
     block_trigger_inputs = ["control_mode", "control_image", None]
 
@@ -3620,8 +3762,8 @@ class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
     def description(self):
         return "Denoise step that denoise the latents.\n" + \
                "This is an auto pipeline block that works for controlnet, controlnet_union and no controlnet.\n" + \
-               " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
-               " - `StableDiffusionXLControlStep` (controlnet) is used when `control_image` is provided.\n" + \
+               " - `StableDiffusionXLControlNetUnionStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
+               " - `StableDiffusionXLControlNetStep` (controlnet) is used when `control_image` is provided.\n" + \
                " - `StableDiffusionXLDenoiseStep` (unet only) is used when both `control_mode` and `control_image` are not provided."
 
 # After denoise
@@ -3733,7 +3875,7 @@ CONTROLNET_BLOCKS = OrderedDict([
 ])
 
 CONTROLNET_UNION_BLOCKS = OrderedDict([
-    ("denoise", StableDiffusionXLControlNetUnionDenoiseStep),
+    ("denoise", StableDiffusionXLControlNetUnionStep),
 ])
 
 IP_ADAPTER_BLOCKS = OrderedDict([
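
Reviewer note: a minimal usage sketch of how the relocated `StableDiffusionXLControlNetUnionStep` composes, plus a worked example of the `controlnet_keep` schedule computed in `StableDiffusionXLControlNetUnionInputStep`. This is illustrative only and not part of the patch; instantiating a block with no arguments follows the `SequentialPipelineBlocks`/`AutoPipelineBlocks` pattern visible in this file, but everything outside the patched module is an assumption.

# Usage sketch (illustration, not part of the patch).
union_step = StableDiffusionXLControlNetUnionStep()
print(union_step.description)
# Within StableDiffusionXLAutoDenoiseStep, block_trigger_inputs = ["control_mode", "control_image", None]
# means: `control_mode` present selects the controlnet_union branch, `control_image` alone selects the
# controlnet branch, and neither falls through to the plain unet denoise branch.

# Worked example of the controlnet_keep schedule from the input step:
# keep[i] is 1.0 while step i falls inside [control_guidance_start, control_guidance_end], else 0.0.
num_steps = 10
start, end = 0.0, 0.5
keep = [
    1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
    for i in range(num_steps)
]
assert keep == [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]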