diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index e7b7d2550e..516a517897 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                 description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
             ),
             InputParam(
-                "processed_mask_image",
+                "mask_latents",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -643,44 +643,43 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
-    def check_inputs(self, image_latents, mask, masked_image_latents):
+    def check_inputs(self, batch_size, image_latents, mask_latents, masked_image_latents):
-        if image_latents.shape[0] != 1:
-            raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}")
-        if mask.shape[0] != 1:
-            raise ValueError(f"mask should have have batch size 1, but got {mask.shape[0]}")
-        if masked_image_latents is not None and masked_image_latents.shape[0] != 1:
-            raise ValueError(f"masked_image_latents should have have batch size 1, but got {masked_image_latents.shape[0]}")
+        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
+            raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
+
+        if not (mask_latents.shape[0] == 1 or mask_latents.shape[0] == batch_size):
+            raise ValueError(f"mask_latents should have batch size 1 or {batch_size}, but got {mask_latents.shape[0]}")
-        if latent_timestep is not None and len(latent_timestep.shape) > 0:
-            raise ValueError(f"latent_timestep should be a scalar, but got {latent_timestep.shape}")
+        if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
+            raise ValueError(f"masked_image_latents should have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")

     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep)
+        self.check_inputs(
+            batch_size=block_state.batch_size,
+            image_latents=block_state.image_latents,
+            mask_latents=block_state.mask_latents,
+            masked_image_latents=block_state.masked_image_latents,
+        )

         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device
-
+
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        _, _, height_latents, width_latents = block_state.image_latents.shape

         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(final_batch_size // block_state.image_latents.shape[0], 1, 1, 1)

         # 7. Prepare mask latent variables
-        block_state.mask = torch.nn.functional.interpolate(
-            block_state.mask, size=(height_latents, width_latents)
-        )
-        block_state.mask = block_state.mask.to(device=device, dtype=dtype)
-        block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)
+        block_state.mask_latents = block_state.mask_latents.to(device=device, dtype=dtype)
+        block_state.mask_latents = block_state.mask_latents.repeat(final_batch_size // block_state.mask_latents.shape[0], 1, 1, 1)

-        if block_state.masked_image_latents is not None:
-            block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
-            block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
+        block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1)

         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
@@ -758,9 +757,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
             )
         ]

-    def check_inputs(self, image_latents):
-        if image_latents.shape[0] != 1:
-            raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}")
+    def check_inputs(self, batch_size, image_latents):
+        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
+            raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")

     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
@@ -778,13 +777,18 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

+        self.check_inputs(
+            batch_size=block_state.batch_size,
+            image_latents=block_state.image_latents,
+        )
+
         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device

         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(final_batch_size // block_state.image_latents.shape[0], 1, 1, 1)

         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
@@ -793,9 +797,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):

         add_noise = True if block_state.denoising_start is None else False
         if add_noise:
-            block_state.latents = prepare_latents(
-                block_state.image_latents,
-                components.scheduler,
+            block_state.latents = self.prepare_latents(
+                image_latents=block_state.image_latents,
+                scheduler=components.scheduler,
                 timestep=block_state.latent_timestep,
                 dtype=dtype,
                 device=device,
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index ac6ebe78fa..186fad0e33 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -594,11 +594,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
                 "image_latents",
                 type_hint=torch.Tensor,
                 description="The latents representing the reference image for image-to-image/inpainting generation",
-            ),
-            OutputParam(
-                "processed_image",
-                type_hint=PIL.Image.Image,
-                description="The preprocessed image",
             )
         ]
@@ -686,14 +681,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                 description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
             ),
             OutputParam(
-                "processed_image",
-                type_hint=PIL.Image.Image,
-                description="The preprocessed image",
-            ),
-            OutputParam(
-                "processed_mask_image",
+                "mask_latents",
                 type_hint=torch.Tensor,
-                description="The preprocessed mask image",
+                description="The mask to apply on the latents for the inpainting generation.",
             ),
         ]
@@ -733,7 +723,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             crops_coords = None
             resize_mode = "default"

-        block_state.processed_image = components.image_processor.preprocess(
+        processed_image = components.image_processor.preprocess(
             block_state.image,
             height=height,
             width=width,
@@ -741,9 +731,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             resize_mode=resize_mode,
         )

-        block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)
+        processed_image = processed_image.to(dtype=torch.float32)

-        block_state.processed_mask_image = components.mask_processor.preprocess(
+        processed_mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
@@ -751,17 +741,18 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             crops_coords=crops_coords,
         )

-        masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5)
+        masked_image = processed_image * (processed_mask_image < 0.5)

+        # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
-            image=block_state.processed_image,
+            image=processed_image,
             vae=components.vae,
             generator=block_state.generator,
             dtype=dtype,
             device=device
         )

-        # 7. Prepare mask latent variables
+        # Prepare masked image latent variables
         block_state.masked_image_latents = encode_vae_image(
             image=masked_image,
             vae=components.vae,
@@ -769,6 +760,14 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             dtype=dtype,
             device=device
         )
+
+        # resize mask to match the image latents
+        _, _, height_latents, width_latents = block_state.image_latents.shape
+        block_state.mask_latents = torch.nn.functional.interpolate(
+            processed_mask_image,
+            size=(height_latents, width_latents),
+        )
+        block_state.mask_latents = block_state.mask_latents.to(dtype=dtype, device=device)

         self.set_block_state(state, block_state)
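
For reference, a minimal standalone sketch (not part of the patch) of the batch-expansion rule the new repeat(final_batch_size // latents.shape[0], ...) calls implement; the names mirror the block-state fields above, and the shapes and values are purely illustrative:

# Sketch only: latents arriving with batch 1 or with batch_size are tiled up
# to the final batch (batch_size * num_images_per_prompt). Illustrative values.
import torch

batch_size = 2                      # number of prompts
num_images_per_prompt = 3
final_batch_size = batch_size * num_images_per_prompt   # 6

for initial_batch in (1, batch_size):
    image_latents = torch.randn(initial_batch, 4, 64, 64)
    # .repeat tiles along the batch dim, e.g. [a, b] -> [a, b, a, b, a, b]
    expanded = image_latents.repeat(final_batch_size // image_latents.shape[0], 1, 1, 1)
    assert expanded.shape == (final_batch_size, 4, 64, 64)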