mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update more
This commit is contained in:
@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
),
|
||||
InputParam(
|
||||
"processed_mask_image",
|
||||
"mask_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The mask for the inpainting generation. Can be generated in vae_encode step.",
|
||||
@@ -643,44 +643,43 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
|
||||
|
||||
|
||||
def check_inputs(self, image_latents, mask, masked_image_latents):
|
||||
def check_inputs(self, batch_size, image_latents, mask_latents, masked_image_latents):
|
||||
|
||||
if image_latents.shape[0] != 1:
|
||||
raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}")
|
||||
if mask.shape[0] != 1:
|
||||
raise ValueError(f"mask should have have batch size 1, but got {mask.shape[0]}")
|
||||
if masked_image_latents is not None and masked_image_latents.shape[0] != 1:
|
||||
raise ValueError(f"masked_image_latents should have have batch size 1, but got {masked_image_latents.shape[0]}")
|
||||
if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
|
||||
raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
|
||||
|
||||
if not (mask_latents.shape[0] == 1 or mask_latents.shape[0] == batch_size):
|
||||
raise ValueError(f"mask_latents should have have batch size 1 or {batch_size}, but got {mask_latents.shape[0]}")
|
||||
|
||||
if latent_timestep is not None and len(latent_timestep.shape) > 0:
|
||||
raise ValueError(f"latent_timestep should be a scalar, but got {latent_timestep.shape}")
|
||||
if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
|
||||
raise ValueError(f"masked_image_latents should have have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep)
|
||||
self.check_inputs(
|
||||
batch_size=block_state.batch_size,
|
||||
image_latents=block_state.image_latents,
|
||||
mask_latents=block_state.mask_latents,
|
||||
masked_image_latents=block_state.masked_image_latents,
|
||||
)
|
||||
|
||||
dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
|
||||
device = components._execution_device
|
||||
|
||||
|
||||
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
_, _, height_latents, width_latents = block_state.image_latents.shape
|
||||
|
||||
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
|
||||
block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
|
||||
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
block_state.mask = torch.nn.functional.interpolate(
|
||||
block_state.mask, size=(height_latents, width_latents)
|
||||
)
|
||||
block_state.mask = block_state.mask.to(device=device, dtype=dtype)
|
||||
block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)
|
||||
block_state.mask_latents = block_state.mask_latents.to(device=device, dtype=dtype)
|
||||
block_state.mask_latents = block_state.mask_latents.repeat(final_batch_size//block_state.mask_latents.shape[0], 1, 1, 1)
|
||||
|
||||
if block_state.masked_image_latents is not None:
|
||||
block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
|
||||
block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)
|
||||
block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
|
||||
block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
if block_state.latent_timestep is not None:
|
||||
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
|
||||
@@ -758,9 +757,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
)
|
||||
]
|
||||
|
||||
def check_inputs(self, image_latents):
|
||||
if image_latents.shape[0] != 1:
|
||||
raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}")
|
||||
def check_inputs(self, batch_size, image_latents):
|
||||
if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
|
||||
raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
|
||||
|
||||
def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
|
||||
if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
|
||||
@@ -778,13 +777,18 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
batch_size=block_state.batch_size,
|
||||
image_latents=block_state.image_latents,
|
||||
)
|
||||
|
||||
dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
|
||||
device = components._execution_device
|
||||
|
||||
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
|
||||
block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
|
||||
block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
if block_state.latent_timestep is not None:
|
||||
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
|
||||
@@ -793,9 +797,9 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
add_noise = True if block_state.denoising_start is None else False
|
||||
|
||||
if add_noise:
|
||||
block_state.latents = prepare_latents(
|
||||
block_state.image_latents,
|
||||
components.scheduler,
|
||||
block_state.latents = self.prepare_latents(
|
||||
image_latents=block_state.image_latents,
|
||||
scheduler=components.scheduler,
|
||||
timestep=block_state.latent_timestep,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
||||
@@ -594,11 +594,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
),
|
||||
OutputParam(
|
||||
"processed_image",
|
||||
type_hint=PIL.Image.Image,
|
||||
description="The preprocessed image",
|
||||
)
|
||||
]
|
||||
|
||||
@@ -686,14 +681,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
|
||||
),
|
||||
OutputParam(
|
||||
"processed_image",
|
||||
type_hint=PIL.Image.Image,
|
||||
description="The preprocessed image",
|
||||
),
|
||||
OutputParam(
|
||||
"processed_mask_image",
|
||||
"mask_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The preprocessed mask image",
|
||||
description="The mask to apply on the latents for the inpainting generation.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -733,7 +723,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
processed_image = components.image_processor.preprocess(
|
||||
block_state.image,
|
||||
height=height,
|
||||
width=width,
|
||||
@@ -741,9 +731,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)
|
||||
processed_image = processed_image.to(dtype=torch.float32)
|
||||
|
||||
block_state.processed_mask_image = components.mask_processor.preprocess(
|
||||
processed_mask_image = components.mask_processor.preprocess(
|
||||
block_state.mask_image,
|
||||
height=height,
|
||||
width=width,
|
||||
@@ -751,17 +741,18 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
crops_coords=crops_coords,
|
||||
)
|
||||
|
||||
masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5)
|
||||
masked_image = processed_image * (block_state.mask_latents < 0.5)
|
||||
|
||||
# Prepare image latent variables
|
||||
block_state.image_latents = encode_vae_image(
|
||||
image=block_state.processed_image,
|
||||
image=processed_image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
# Prepare masked image latent variables
|
||||
block_state.masked_image_latents = encode_vae_image(
|
||||
image=masked_image,
|
||||
vae=components.vae,
|
||||
@@ -769,6 +760,14 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
# resize mask to match the image latents
|
||||
_, _, height_latents, width_latents = block_state.image_latents.shape
|
||||
block_state.mask_latents = torch.nn.functional.interpolate(
|
||||
processed_mask_image,
|
||||
size=(height_latents, width_latents),
|
||||
)
|
||||
block_state.mask_latents = block_state.mask_latents.to(dtype=dtype, device=device)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user