diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index 34c5cc275d..3e5301987e 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -607,8 +607,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
             ),
         ]
 
+    @staticmethod
     def prepare_latents(
-        self,
         image_latents,
         scheduler,
         dtype,
@@ -760,6 +760,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
             raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
 
+    @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
             raise ValueError(
@@ -975,6 +976,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
         ]
 
     @property
@@ -992,7 +994,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
                 kwargs_type="guider_input_fields",
                 description="The negative time ids to condition the denoising process",
             ),
-            OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
         ]
 
     @staticmethod
@@ -1136,6 +1137,11 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]
 
     @property
@@ -1187,8 +1193,8 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
         _, _, height_latents, width_latents = block_state.latents.shape
         height = height_latents * components.vae_scale_factor
         width = width_latents * components.vae_scale_factor
 
-        original_size = block_state.original_size or (block_state.height, block_state.width)
-        target_size = block_state.target_size or (block_state.height, block_state.width)
+        original_size = block_state.original_size or (height, width)
+        target_size = block_state.target_size or (height, width)
 
         block_state.add_time_ids = self._get_add_time_ids(
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
index e312f9c860..6ed5d72bf6 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -139,7 +139,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
             block_state.images = components.vae.decode(latents, return_dict=False)[0]
 
             # cast back to fp16 if needed
-            if block_state.needs_upcasting:
+            if needs_upcasting:
                 components.vae.to(dtype=torch.float16)
         else:
             block_state.images = block_state.latents
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index 28ece71453..ea95d850d0 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -39,6 +39,8 @@
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import StableDiffusionXLModularPipeline
 
+from PIL import Image
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -504,11 +506,11 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
             negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
 
-        prompt_embeds = prompt_embeds.to(dtype, device=device)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
         if requires_unconditional_embeds:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
 
         for text_encoder in text_encoders:
             if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
@@ -687,12 +689,12 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
 
 
     def check_inputs(self, image, mask_image, padding_mask_crop):
-        if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(image, Image.Image):
             raise ValueError(
                 f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
             )
 
-        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
             raise ValueError(
                 f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                 f" {type(mask_image)}."
@@ -707,10 +709,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
         device = components._execution_device
 
-        if block_state.height is None:
-            height = components.default_height
-        if block_state.width is None:
-            width = components.default_width
+        height = block_state.height if block_state.height is not None else components.default_height
+        width = block_state.width if block_state.width is not None else components.default_width
 
         if block_state.padding_mask_crop is not None:
             block_state.crops_coords = components.mask_processor.get_crop_region(
@@ -725,21 +725,21 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             block_state.image,
             height=height,
             width=width,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
             resize_mode=resize_mode,
         )
 
         image = image.to(dtype=torch.float32)
 
-        mask = components.mask_processor.preprocess(
+        mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
             resize_mode=resize_mode,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
         )
 
-        masked_image = image * (block_state.mask_latents < 0.5)
+        masked_image = image * (mask_image < 0.5)
 
         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
@@ -762,7 +762,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
         block_state.mask = torch.nn.functional.interpolate(
-            mask,
+            mask_image,
             size=(height_latents, width_latents),
         )
         block_state.mask = block_state.mask.to(dtype=dtype, device=device)
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index 2a52c70176..f337a89a8b 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -95,11 +95,15 @@ class StableDiffusionXLModularPipeline(
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
 
-        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
+            # LCM
             requires_unconditional_embeds = False
         elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
+
+        elif not hasattr(self, "guider") or self.guider is None:
+            requires_unconditional_embeds = False
 
         return requires_unconditional_embeds