mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

commit fe2a6a35e9
parent dc6a4d4cb4
Author: yiyixuxu
Date:   2025-08-06 21:57:22 +02:00

    update more

2 changed files with 50 additions and 47 deletions


@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                 description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
             ),
             InputParam(
-                "processed_mask_image",
+                "mask_latents",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -643,44 +643,43 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
-    def check_inputs(self, image_latents, mask, masked_image_latents):
-        if image_latents.shape[0] != 1:
-            raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")
-        if mask.shape[0] != 1:
-            raise ValueError(f"mask should have batch size 1, but got {mask.shape[0]}")
-        if masked_image_latents is not None and masked_image_latents.shape[0] != 1:
-            raise ValueError(f"masked_image_latents should have batch size 1, but got {masked_image_latents.shape[0]}")
-        if latent_timestep is not None and len(latent_timestep.shape) > 0:
-            raise ValueError(f"latent_timestep should be a scalar, but got {latent_timestep.shape}")
+    def check_inputs(self, batch_size, image_latents, mask_latents, masked_image_latents):
+        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
+            raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
+        if not (mask_latents.shape[0] == 1 or mask_latents.shape[0] == batch_size):
+            raise ValueError(f"mask_latents should have batch size 1 or {batch_size}, but got {mask_latents.shape[0]}")
+        if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
+            raise ValueError(f"masked_image_latents should have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")

     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep)
+        self.check_inputs(
+            batch_size=block_state.batch_size,
+            image_latents=block_state.image_latents,
+            mask_latents=block_state.mask_latents,
+            masked_image_latents=block_state.masked_image_latents,
+        )

         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
         _, _, height_latents, width_latents = block_state.image_latents.shape

         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(final_batch_size // block_state.image_latents.shape[0], 1, 1, 1)

-        # 7. Prepare mask latent variables
-        block_state.mask = torch.nn.functional.interpolate(
-            block_state.mask, size=(height_latents, width_latents)
-        )
-        block_state.mask = block_state.mask.to(device=device, dtype=dtype)
-        block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)
+        block_state.mask_latents = block_state.mask_latents.to(device=device, dtype=dtype)
+        block_state.mask_latents = block_state.mask_latents.repeat(final_batch_size // block_state.mask_latents.shape[0], 1, 1, 1)

-        if block_state.masked_image_latents is not None:
-            block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
-            block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
+        block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size // block_state.masked_image_latents.shape[0], 1, 1, 1)

         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
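
The validate-then-repeat pattern above is the heart of this hunk: a tensor arriving with batch size 1 is shared across the whole batch, while one arriving already at batch_size is only expanded by num_images_per_prompt. A minimal self-contained sketch of the pattern; the helper name expand_to_batch and the example shapes are illustrative, not from the commit:

import torch

def expand_to_batch(t: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    # Accept batch size 1 (one reference shared by all prompts) or
    # batch_size (one reference per prompt); reject anything else.
    if t.shape[0] not in (1, batch_size):
        raise ValueError(f"expected batch size 1 or {batch_size}, but got {t.shape[0]}")
    final_batch_size = batch_size * num_images_per_prompt
    # In both accepted cases t.shape[0] divides final_batch_size exactly,
    # so integer division gives the right repeat factor.
    return t.repeat(final_batch_size // t.shape[0], 1, 1, 1)

latents = torch.randn(1, 4, 128, 128)
print(expand_to_batch(latents, batch_size=2, num_images_per_prompt=2).shape)  # torch.Size([4, 4, 128, 128])
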
@@ -758,9 +757,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
             )
         ]

-    def check_inputs(self, image_latents):
-        if image_latents.shape[0] != 1:
-            raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")
+    def check_inputs(self, batch_size, image_latents):
+        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
+            raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")

     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
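
The context line above shows prepare_latents validating a list of generators against the latent batch. The idea is that a list pairs one RNG with each sample for per-sample reproducibility, so the lengths must match. A small illustration of that contract, with made-up shapes:

import torch

# One generator per latent in the batch: lengths must agree.
generators = [torch.Generator().manual_seed(i) for i in range(2)]
image_latents = torch.randn(2, 4, 128, 128)
if isinstance(generators, list) and len(generators) != image_latents.shape[0]:
    raise ValueError(f"got {len(generators)} generators for a batch of {image_latents.shape[0]}")
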
@@ -778,13 +777,18 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
     def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

+        self.check_inputs(
+            batch_size=block_state.batch_size,
+            image_latents=block_state.image_latents,
+        )
+
         dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
         device = components._execution_device
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

         block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
-        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
+        block_state.image_latents = block_state.image_latents.repeat(final_batch_size // block_state.image_latents.shape[0], 1, 1, 1)

         if block_state.latent_timestep is not None:
             block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
@@ -793,9 +797,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
         add_noise = True if block_state.denoising_start is None else False
         if add_noise:
-            block_state.latents = prepare_latents(
-                block_state.image_latents,
-                components.scheduler,
+            block_state.latents = self.prepare_latents(
+                image_latents=block_state.image_latents,
+                scheduler=components.scheduler,
                 timestep=block_state.latent_timestep,
                 dtype=dtype,
                 device=device,
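
The body of prepare_latents is not shown in this diff. In the standard SDXL img2img pipeline the equivalent helper noises the encoded image up to the chosen starting timestep; a hedged sketch of that pattern, assuming the same behavior here:

import torch
from diffusers.utils.torch_utils import randn_tensor

def prepare_latents_sketch(image_latents, scheduler, timestep, dtype, device, generator=None):
    # Sketch only: noise the image latents to the starting timestep,
    # mirroring the usual img2img latent preparation in diffusers.
    image_latents = image_latents.to(device=device, dtype=dtype)
    noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
    return scheduler.add_noise(image_latents, noise, timestep)
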


@@ -594,11 +594,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
                 "image_latents",
                 type_hint=torch.Tensor,
                 description="The latents representing the reference image for image-to-image/inpainting generation",
             ),
-            OutputParam(
-                "processed_image",
-                type_hint=PIL.Image.Image,
-                description="The preprocessed image",
-            )
         ]
@@ -686,14 +681,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                 description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
             ),
-            OutputParam(
-                "processed_image",
-                type_hint=PIL.Image.Image,
-                description="The preprocessed image",
-            ),
             OutputParam(
-                "processed_mask_image",
+                "mask_latents",
                 type_hint=torch.Tensor,
-                description="The preprocessed mask image",
+                description="The mask to apply on the latents for the inpainting generation.",
             ),
         ]
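
This OutputParam rename pairs with the InputParam rename in the first file: modular blocks exchange intermediates by name through the pipeline state, so the producing encode step and the consuming prepare-latents step must agree on "mask_latents". A toy, dict-based illustration of that contract, not the diffusers implementation:

import torch

# Toy stand-in for the pipeline state: the VAE-encode step writes the tensor...
state = {}
state["mask_latents"] = torch.zeros(1, 1, 128, 128)

# ...and the prepare-latents step declares it as a required input by name.
required_input = "mask_latents"
assert required_input in state, f"{required_input} missing from pipeline state"
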
@@ -733,7 +723,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             crops_coords = None
             resize_mode = "default"

-        block_state.processed_image = components.image_processor.preprocess(
+        processed_image = components.image_processor.preprocess(
             block_state.image,
             height=height,
             width=width,
@@ -741,9 +731,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             resize_mode=resize_mode,
         )
-        block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)
+        processed_image = processed_image.to(dtype=torch.float32)

-        block_state.processed_mask_image = components.mask_processor.preprocess(
+        processed_mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
@@ -751,17 +741,18 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             crops_coords=crops_coords,
         )

-        masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5)
+        masked_image = processed_image * (processed_mask_image < 0.5)

         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
-            image=block_state.processed_image,
+            image=processed_image,
             vae=components.vae,
             generator=block_state.generator,
             dtype=dtype,
             device=device
         )

-        # 7. Prepare mask latent variables
+        # Prepare masked image latent variables
         block_state.masked_image_latents = encode_vae_image(
             image=masked_image,
             vae=components.vae,
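
masked_image keeps only the pixels where the mask is below 0.5 (the region to preserve) and blanks the region to repaint, so the VAE encodes a "hole-punched" image. encode_vae_image itself is defined outside this diff; a sketch of the usual diffusers VAE-encoding pattern it presumably follows, with assumed defaults:

import torch

def encode_vae_image_sketch(image, vae, generator=None, dtype=torch.float32, device="cpu"):
    # Sketch only: encode to the latent distribution, sample from it,
    # and apply the VAE's scaling factor, as standard diffusers pipelines do.
    image = image.to(device=device, dtype=vae.dtype)
    latents = vae.encode(image).latent_dist.sample(generator=generator)
    return (latents * vae.config.scaling_factor).to(dtype=dtype)
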
@@ -769,6 +760,14 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
             dtype=dtype,
             device=device
         )

+        # resize mask to match the image latents
+        _, _, height_latents, width_latents = block_state.image_latents.shape
+        block_state.mask_latents = torch.nn.functional.interpolate(
+            processed_mask_image,
+            size=(height_latents, width_latents),
+        )
+        block_state.mask_latents = block_state.mask_latents.to(dtype=dtype, device=device)
+
         self.set_block_state(state, block_state)
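
This added block is why the prepare-latents step could drop its own interpolate call: the mask is brought down to latent resolution once, at encode time. A runnable sketch of the resize with illustrative shapes (1024x1024 pixels to the 128x128 latent grid of the SDXL VAE's 8x spatial compression):

import torch
import torch.nn.functional as F

# Pixel-space binary mask and the latents it must match.
mask = (torch.rand(1, 1, 1024, 1024) > 0.5).float()
image_latents = torch.randn(1, 4, 128, 128)

# Nearest-neighbor downsample (F.interpolate's default mode) to latent size.
_, _, height_latents, width_latents = image_latents.shape
mask_latents = F.interpolate(mask, size=(height_latents, width_latents))
print(mask_latents.shape)  # torch.Size([1, 1, 128, 128])
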