From ed881a15fddd00f95725aaa8ed099291cdda612d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 7 Aug 2025 23:34:22 +0200 Subject: [PATCH] up up up --- .../stable_diffusion_xl/before_denoise.py | 257 +++++++++--------- .../stable_diffusion_xl/decoders.py | 16 +- .../stable_diffusion_xl/denoise.py | 14 +- .../stable_diffusion_xl/encoders.py | 32 +-- .../stable_diffusion_xl/modular_blocks.py | 22 ++ .../stable_diffusion_xl/modular_pipeline.py | 5 +- 6 files changed, 187 insertions(+), 159 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 516a517897..34c5cc275d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", ), InputParam( - "mask_latents", + "mask", required=True, type_hint=torch.Tensor, description="The mask for the inpainting generation. Can be generated in vae_encode step.", @@ -607,11 +607,10 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): ), ] - # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument def prepare_latents( self, - components, image_latents, + scheduler, dtype, device, generator, @@ -631,9 +630,9 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): if add_noise: noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) # if strength is 1. 
then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
+            latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
             # if pure noise then scale the initial latents by the Scheduler's init sigma
-            latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
+            latents = latents * scheduler.init_noise_sigma if is_strength_max else latents
         else:
             noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -643,13 +642,13 @@
 
-    def check_inputs(self, batch_size, image_latents, mask_latents, masked_image_latents):
+    def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):
 
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
             raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
 
-        if not (mask_latents.shape[0] == 1 or mask_latents.shape[0] == batch_size):
-            raise ValueError(f"mask_latents should have have batch size 1 or {batch_size}, but got {mask_latents.shape[0]}")
+        if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
+            raise ValueError(f"mask should have batch size 1 or {batch_size}, but got {mask.shape[0]}")
 
         if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
             raise ValueError(f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
@@ -662,7 +661,7 @@
         self.check_inputs(
             batch_size=block_state.batch_size,
             image_latents=block_state.image_latents,
-            mask_latents=block_state.mask_latents,
+            mask=block_state.mask,
             masked_image_latents=block_state.masked_image_latents,
         )
 
@@ -675,8 +674,8 @@
         block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
 
         # 7.
Prepare mask latent variables - block_state.mask_latents = block_state.mask_latents.to(device=device, dtype=dtype) - block_state.mask_latents = block_state.mask_latents.repeat(final_batch_size//block_state.mask_latents.shape[0], 1, 1, 1) + block_state.mask = block_state.mask.to(device=device, dtype=dtype) + block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1) block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype) block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1) @@ -689,8 +688,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): add_noise = True if block_state.denoising_start is None else False block_state.latents, block_state.noise = self.prepare_latents( - components=components, image_latents=block_state.image_latents, + scheduler=components.scheduler, dtype=dtype, device=device, generator=block_state.generator, @@ -945,15 +944,14 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("original_size"), - InputParam("target_size"), InputParam("negative_original_size"), + InputParam("target_size"), InputParam("negative_target_size"), InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), - InputParam("num_images_per_prompt", default=1), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), - InputParam("embedded_guidance_scale", default=7.5), + InputParam("num_images_per_prompt", default=1), ] @property @@ -1050,42 +1048,12 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): return add_time_ids, add_neg_time_ids - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
- """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + device = components._execution_device - dtype = block_state.pooled_prompt_embeds.dtype + dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype final_batch_size = block_state.batch_size * block_state.num_images_per_prompt text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) @@ -1120,19 +1088,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device) block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device) - # Optionally get Guidance Scale Embedding for LCM - block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! - block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=device, dtype=dtype) - self.set_block_state(state, block_state) return components, state @@ -1158,7 +1113,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("embedded_guidance_scale", default=7.5), ] @property @@ -1199,7 +1153,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process", ), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), ] @staticmethod @@ -1222,6 +1175,84 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype + text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + final_batch_size = block_state.batch_size * block_state.num_images_per_prompt + + _, _, height_latents, width_latents = block_state.latents.shape + height = height_latents * components.vae_scale_factor + width = width_latents * components.vae_scale_factor + original_size = 
block_state.original_size or (block_state.height, block_state.width) + target_size = block_state.target_size or (block_state.height, block_state.width) + + + block_state.add_time_ids = self._get_add_time_ids( + components, + original_size, + block_state.crops_coords_top_left, + target_size, + dtype=dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLLCMStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("unet", UNet2DConditionModel),] + + @property + def description(self) -> str: + return "Step that prepares the timestep cond input for latent consistency models" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam("embedded_guidance_scale"), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 @@ -1253,57 +1284,33 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): assert emb.shape == (w.shape[0], embedding_dim) return emb + + def check_input(self, unet, embedded_guidance_scale): + + if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None: + raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None") + + if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None: + raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None") + + @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + device = components._execution_device - dtype = block_state.pooled_prompt_embeds.dtype - text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype final_batch_size = block_state.batch_size * block_state.num_images_per_prompt - _, _, height_latents, width_latents = block_state.latents.shape - height = height_latents * components.vae_scale_factor - width = width_latents * components.vae_scale_factor - block_state.original_size = block_state.original_size or (block_state.height, block_state.width) - block_state.target_size = block_state.target_size or (block_state.height, block_state.width) - - - block_state.add_time_ids = self._get_add_time_ids( - components, - block_state.original_size, - block_state.crops_coords_top_left, - block_state.target_size, - dtype=dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - if block_state.negative_original_size is not None and block_state.negative_target_size is not None: - block_state.negative_add_time_ids = self._get_add_time_ids( - components, - block_state.negative_original_size, - block_state.negative_crops_coords_top_left, - block_state.negative_target_size, - dtype=dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - block_state.negative_add_time_ids = block_state.add_time_ids - - block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device) - block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device) # Optionally get Guidance Scale Embedding for LCM block_state.timestep_cond = None - if ( - hasattr(components, "unet") - and components.unet is not None - and components.unet.config.time_cond_proj_dim is not None - ): - # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
- block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device) - block_state.timestep_cond = self.get_guidance_scale_embedding( - block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim - ).to(device=block_state.device, dtype=block_state.latents.dtype) + + guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device) + block_state.timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=device, dtype=dtype) self.set_block_state(state, block_state) return components, state @@ -1459,29 +1466,29 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): ) # (1.2) - # controlnet_conditioning_scale (align format) + # conditioning_scale (align format) if isinstance(controlnet, MultiControlNetModel) and isinstance( block_state.controlnet_conditioning_scale, float ): - block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( + block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len( controlnet.nets ) + else: + block_state.conditioning_scale = block_state.controlnet_conditioning_scale # (1.3) - # global_pool_conditions - block_state.global_pool_conditions = ( + # guess_mode + global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) - # (1.4) - # guess_mode - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or global_pool_conditions - # (1.5) - # control_image + # (1.4) + # controlnet_cond if isinstance(controlnet, ControlNetModel): - block_state.control_image = self.prepare_control_image( + block_state.controlnet_cond = self.prepare_control_image( components, image=block_state.control_image, width=width, @@ -1510,7 +1517,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): control_images.append(control_image) - block_state.control_image = control_images + block_state.controlnet_cond = control_images else: assert False @@ -1524,9 +1531,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): ] block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - block_state.controlnet_cond = block_state.control_image - block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.set_block_state(state, block_state) return components, state @@ -1686,8 +1690,8 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): ] # guess_mode - block_state.global_pool_conditions = controlnet.config.global_pool_conditions - block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or global_pool_conditions # control_image if not isinstance(block_state.control_image, list): @@ -1700,19 +1704,20 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): raise ValueError("Expected len(control_image) == len(control_type)") # control_type - block_state.num_control_type = controlnet.config.num_control_type - block_state.control_type = [0 for _ in range(block_state.num_control_type)] + num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in 
range(num_control_type)] for control_idx in block_state.control_mode: block_state.control_type[control_idx] = 1 block_state.control_type = torch.Tensor(block_state.control_type) - block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=dtype) repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) - # prepare control_image + # prepare controlnet_cond + block_state.controlnet_cond = [] for idx, _ in enumerate(block_state.control_image): - block_state.control_image[idx] = self.prepare_control_image( + control_image = self.prepare_control_image( components, image=block_state.control_image[idx], width=width, @@ -1723,7 +1728,8 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): dtype=dtype, crops_coords=block_state.crops_coords, ) - _, _, height, width = block_state.control_image[idx].shape + _, _, height, width = control_image.shape + block_state.controlnet_cond.append(control_image) # controlnet_keep block_state.controlnet_keep = [] @@ -1736,7 +1742,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): ) ) block_state.control_type_idx = block_state.control_mode - block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index e9f627636e..e312f9c860 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -105,9 +105,9 @@ class StableDiffusionXLDecodeStep(PipelineBlock): if not block_state.output_type == "latent": latents = block_state.latents # make sure the VAE is in float32 mode, as it overflows in float16 - block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast - if block_state.needs_upcasting: + if needs_upcasting: self.upcast_vae(components) latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != components.vae.dtype: @@ -117,21 +117,21 @@ class StableDiffusionXLDecodeStep(PipelineBlock): # unscale/denormalize the latents # denormalize with the mean and std if available and not None - block_state.has_latents_mean = ( + has_latents_mean = ( hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None ) - block_state.has_latents_std = ( + has_latents_std = ( hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None ) - if block_state.has_latents_mean and block_state.has_latents_std: - block_state.latents_mean = ( + if has_latents_mean and has_latents_std: + latents_mean = ( torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - block_state.latents_std = ( + latents_std = ( torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = ( - latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + latents * latents_std / 
components.vae.config.scaling_factor + latents_mean ) else: latents = latents / components.vae.config.scaling_factor diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 7fe4a472ee..871fafd024 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -67,7 +67,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): @torch.no_grad() def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t) return components, block_state @@ -134,10 +134,10 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) - block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t) if components.num_channels_unet == 9: - block_state.scaled_latents = torch.cat( - [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + block_state.latent_model_input = torch.cat( + [block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1 ) return components, block_state @@ -232,7 +232,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=block_state.timestep_cond, @@ -410,7 +410,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): mid_block_res_sample = block_state.mid_block_res_sample_zeros else: down_block_res_samples, mid_block_res_sample = components.controlnet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=guider_state_batch.prompt_embeds, controlnet_cond=block_state.controlnet_cond, @@ -430,7 +430,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): # Predict the noise # store the noise_pred in guider_state_batch so we can apply guidance across all batches guider_state_batch.noise_pred = components.unet( - block_state.scaled_latents, + block_state.latent_model_input, t, encoder_hidden_states=guider_state_batch.prompt_embeds, timestep_cond=block_state.timestep_cond, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 186fad0e33..28ece71453 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -390,7 +390,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" - device = device or components._execution_device dtype = components.text_encoder_2.dtype @@ -526,7 +525,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2) device = components._execution_device - dtype = components.text_encoder_2.dtype # Encode input prompt lora_scale = ( @@ -542,8 +540,8 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): ) = self.encode_prompt( components, prompt=block_state.prompt, - prompt2=block_state.prompt_2, - device = device, + prompt_2=block_state.prompt_2, + device=device, requires_unconditional_embeds=components.requires_unconditional_embeds, negative_prompt=block_state.negative_prompt, negative_prompt_2=block_state.negative_prompt_2, @@ -604,11 +602,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): device = components._execution_device dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.processed_image = components.image_processor.preprocess(block_state.image) + image = components.image_processor.preprocess(block_state.image) # Encode image into latents block_state.image_latents = encode_vae_image( - image=block_state.processed_image, + image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, @@ -681,7 +679,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): description="The crop coordinates to use for the preprocess/postprocess of the image and mask", ), OutputParam( - "mask_latents", + "mask", type_hint=torch.Tensor, description="The mask to apply on the latents for the inpainting generation.", ), @@ -715,15 +713,15 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): width = components.default_width if block_state.padding_mask_crop is not None: - crops_coords = components.mask_processor.get_crop_region( + block_state.crops_coords = components.mask_processor.get_crop_region( mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop ) resize_mode = "fill" else: - crops_coords = None + block_state.crops_coords = None resize_mode = "default" - processed_image = components.image_processor.preprocess( + image = components.image_processor.preprocess( block_state.image, height=height, width=width, @@ -731,9 +729,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): resize_mode=resize_mode, ) - processed_image = processed_image.to(dtype=torch.float32) + image = image.to(dtype=torch.float32) - processed_mask_image = components.mask_processor.preprocess( + mask = components.mask_processor.preprocess( block_state.mask_image, height=height, width=width, @@ -741,11 +739,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): crops_coords=crops_coords, ) - masked_image = processed_image * (block_state.mask_latents < 0.5) + masked_image = image * (block_state.mask_latents < 0.5) # Prepare image latent variables block_state.image_latents = encode_vae_image( - image=processed_image, + image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, @@ -763,11 +761,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): # resize mask to match the image latents _, _, height_latents, width_latents = block_state.image_latents.shape - block_state.mask_latents = torch.nn.functional.interpolate( - processed_mask_image, + block_state.mask = torch.nn.functional.interpolate( + mask, size=(height_latents, width_latents), ) - block_state.mask_latents = 
+        block_state.mask = block_state.mask.to(dtype=dtype, device=device)
 
         self.set_block_state(state, block_state)
 
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
index c9033856bc..c38eb8c632 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -26,6 +26,7 @@ from .before_denoise import (
     StableDiffusionXLPrepareAdditionalConditioningStep,
     StableDiffusionXLPrepareLatentsStep,
     StableDiffusionXLSetTimestepsStep,
+    StableDiffusionXLLCMStep,
 )
 from .decoders import (
     StableDiffusionXLDecodeStep,
@@ -79,6 +80,16 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
         return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
 
 
+class StableDiffusionXLAutoLCMStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLLCMStep]
+    block_names = ["lcm"]
+    block_trigger_inputs = ["embedded_guidance_scale"]
+
+    @property
+    def description(self):
+        return "Run LCM step if `embedded_guidance_scale` is provided. This step should be placed after the 'input' step.\n"
+
+
 # before_denoise: text2img
 class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
@@ -262,6 +273,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
         StableDiffusionXLAutoIPAdapterStep,
         StableDiffusionXLAutoVaeEncoderStep,
         StableDiffusionXLAutoBeforeDenoiseStep,
+        StableDiffusionXLAutoLCMStep,
         StableDiffusionXLAutoControlNetInputStep,
         StableDiffusionXLAutoDenoiseStep,
         StableDiffusionXLAutoDecodeStep,
@@ -271,6 +283,7 @@
         "ip_adapter",
         "image_encoder",
         "before_denoise",
+        "lcm",
         "controlnet_input",
         "denoise",
         "decoder",
@@ -286,6 +299,7 @@
             + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
             + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
             + "- for text-to-image generation, all you need to provide is `prompt`"
+            + "\n- to run the latent consistency models workflow, you need to provide `embedded_guidance_scale`"
         )
 
 
@@ -357,6 +371,13 @@
     ]
 )
 
+LCM_BLOCKS = InsertableDict(
+
+    [
+        ("lcm", StableDiffusionXLAutoLCMStep),
+    ]
+)
+
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", StableDiffusionXLTextEncoderStep),
@@ -376,5 +397,6 @@
     "inpaint": INPAINT_BLOCKS,
     "controlnet": CONTROLNET_BLOCKS,
     "ip_adapter": IP_ADAPTER_BLOCKS,
+    "lcm": LCM_BLOCKS,
     "auto": AUTO_BLOCKS,
 }
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index 84dd0c0ee3..2a52c70176 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -95,7 +95,10 @@ class StableDiffusionXLModularPipeline(
 
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
 
-        if hasattr(self, "guider") and self.guider is not None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
+            requires_unconditional_embeds = False
+
+        elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
 
         return requires_unconditional_embeds
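
Usage sketch (not part of the patch): how the LCM pieces introduced above are reached from Python. It only touches names defined in this diff (`StableDiffusionXLAutoBlocks`, `ALL_BLOCKS["lcm"]`, the `embedded_guidance_scale` trigger); pipeline assembly and inference calls are deliberately omitted because their exact API is not shown here, so treat this as illustrative rather than canonical.

from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import (
    ALL_BLOCKS,
    StableDiffusionXLAutoBlocks,
)

# The auto workflow now contains an "lcm" sub-block that only executes when
# `embedded_guidance_scale` is passed to the assembled pipeline (its trigger input).
auto_blocks = StableDiffusionXLAutoBlocks()
print(auto_blocks)

# The standalone preset added by this patch; it can be combined with other
# presets the same way as INPAINT_BLOCKS / CONTROLNET_BLOCKS / IP_ADAPTER_BLOCKS.
print(ALL_BLOCKS["lcm"])

# At inference time (assembly not shown), an LCM-distilled UNet
# (unet.config.time_cond_proj_dim set) plus `embedded_guidance_scale=8.0` makes
# StableDiffusionXLLCMStep emit `timestep_cond`, and requires_unconditional_embeds
# becomes False so no negative prompt embeddings are computed.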