Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
up up up
@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
            ),
            InputParam(
                "mask_latents",
                "mask",
                required=True,
                type_hint=torch.Tensor,
                description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -607,11 +607,10 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
            ),
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
    def prepare_latents(
        self,
        components,
        image_latents,
        scheduler,
        dtype,
        device,
        generator,
@@ -631,9 +630,9 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
        if add_noise:
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
            latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
            latents = latents * scheduler.init_noise_sigma if is_strength_max else latents

        else:
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
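Note: a minimal standalone sketch of the add-noise logic in the hunk above, using DDPMScheduler as a stand-in for whichever scheduler the pipeline components actually hold; shapes and the timestep value are illustrative only.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
image_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(image_latents)
timestep = torch.tensor([999])

is_strength_max = False  # strength < 1.0: start from image + noise rather than pure noise
latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
# for strength == 1.0 the latents are pure noise and get scaled by the scheduler's init sigma
latents = latents * scheduler.init_noise_sigma if is_strength_max else latents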
@@ -643,13 +642,13 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):

    def check_inputs(self, batch_size, image_latents, mask_latents, masked_image_latents):
    def check_inputs(self, batch_size, image_latents, mask, masked_image_latents):

        if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
            raise ValueError(f"image_latents should have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")

        if not (mask_latents.shape[0] == 1 or mask_latents.shape[0] == batch_size):
            raise ValueError(f"mask_latents should have batch size 1 or {batch_size}, but got {mask_latents.shape[0]}")
        if not (mask.shape[0] == 1 or mask.shape[0] == batch_size):
            raise ValueError(f"mask should have batch size 1 or {batch_size}, but got {mask.shape[0]}")

        if not (masked_image_latents.shape[0] == 1 or masked_image_latents.shape[0] == batch_size):
            raise ValueError(f"masked_image_latents should have batch size 1 or {batch_size}, but got {masked_image_latents.shape[0]}")
@@ -662,7 +661,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
        self.check_inputs(
            batch_size=block_state.batch_size,
            image_latents=block_state.image_latents,
            mask_latents=block_state.mask_latents,
            mask=block_state.mask,
            masked_image_latents=block_state.masked_image_latents,
        )

@@ -675,8 +674,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
        block_state.image_latents = block_state.image_latents.repeat(final_batch_size//block_state.image_latents.shape[0], 1, 1, 1)

        # 7. Prepare mask latent variables
        block_state.mask_latents = block_state.mask_latents.to(device=device, dtype=dtype)
        block_state.mask_latents = block_state.mask_latents.repeat(final_batch_size//block_state.mask_latents.shape[0], 1, 1, 1)
        block_state.mask = block_state.mask.to(device=device, dtype=dtype)
        block_state.mask = block_state.mask.repeat(final_batch_size//block_state.mask.shape[0], 1, 1, 1)

        block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
        block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size//block_state.masked_image_latents.shape[0], 1, 1, 1)
@@ -689,8 +688,8 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
        add_noise = True if block_state.denoising_start is None else False

        block_state.latents, block_state.noise = self.prepare_latents(
            components=components,
            image_latents=block_state.image_latents,
            scheduler=components.scheduler,
            dtype=dtype,
            device=device,
            generator=block_state.generator,
@@ -945,15 +944,14 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("original_size"),
            InputParam("target_size"),
            InputParam("negative_original_size"),
            InputParam("target_size"),
            InputParam("negative_target_size"),
            InputParam("crops_coords_top_left", default=(0, 0)),
            InputParam("negative_crops_coords_top_left", default=(0, 0)),
            InputParam("num_images_per_prompt", default=1),
            InputParam("aesthetic_score", default=6.0),
            InputParam("negative_aesthetic_score", default=2.0),
            InputParam("embedded_guidance_scale", default=7.5),
            InputParam("num_images_per_prompt", default=1),
        ]

    @property
@@ -1050,42 +1048,12 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):

        return add_time_ids, add_neg_time_ids

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        dtype = block_state.pooled_prompt_embeds.dtype
        dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
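Note: the sin/cos construction in the hunk above is the standard LCM guidance-scale embedding; a minimal self-contained sketch follows (the embedding_dim value and batch size are illustrative only).

import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # same construction as get_guidance_scale_embedding above, with dtype fixed to float32
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = w.to(torch.float32)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad when the dimension is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape (len(w), embedding_dim)

timestep_cond = guidance_scale_embedding(torch.tensor([7.5 - 1.0]).repeat(2), embedding_dim=256)
assert timestep_cond.shape == (2, 256)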
@@ -1120,19 +1088,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)

        # Optionally get Guidance Scale Embedding for LCM
        block_state.timestep_cond = None
        if (
            hasattr(components, "unet")
            and components.unet is not None
            and components.unet.config.time_cond_proj_dim is not None
        ):
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
            block_state.timestep_cond = self.get_guidance_scale_embedding(
                block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=dtype)

        self.set_block_state(state, block_state)
        return components, state

@@ -1158,7 +1113,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
            InputParam("crops_coords_top_left", default=(0, 0)),
            InputParam("negative_crops_coords_top_left", default=(0, 0)),
            InputParam("num_images_per_prompt", default=1),
            InputParam("embedded_guidance_scale", default=7.5),
        ]

    @property
@@ -1199,7 +1153,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                kwargs_type="guider_input_fields",
                description="The negative time ids to condition the denoising process",
            ),
            OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
        ]

    @staticmethod
@@ -1222,6 +1175,84 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = block_state.dtype if block_state.dtype is not None else block_state.pooled_prompt_embeds.dtype
        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor
        original_size = block_state.original_size or (block_state.height, block_state.width)
        target_size = block_state.target_size or (block_state.height, block_state.width)

        block_state.add_time_ids = self._get_add_time_ids(
            components,
            original_size,
            block_state.crops_coords_top_left,
            target_size,
            dtype=dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
            block_state.negative_add_time_ids = self._get_add_time_ids(
                components,
                block_state.negative_original_size,
                block_state.negative_crops_coords_top_left,
                block_state.negative_target_size,
                dtype=dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            block_state.negative_add_time_ids = block_state.add_time_ids

        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)

        self.set_block_state(state, block_state)
        return components, state


class StableDiffusionXLLCMStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("unet", UNet2DConditionModel),]

    @property
    def description(self) -> str:
        return "Step that prepares the timestep cond input for latent consistency models"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam("embedded_guidance_scale"),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
        ]

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
@@ -1253,57 +1284,33 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    def check_input(self, unet, embedded_guidance_scale):

        if embedded_guidance_scale is not None and unet.config.time_cond_proj_dim is None:
            raise ValueError(f"cannot use `embedded_guidance_scale` {embedded_guidance_scale} because unet.config.time_cond_proj_dim is None")

        if embedded_guidance_scale is None and unet.config.time_cond_proj_dim is not None:
            raise ValueError(f"unet.config.time_cond_proj_dim is not None, but `embedded_guidance_scale` is None")

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        dtype = block_state.pooled_prompt_embeds.dtype
        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
        dtype = block_state.dtype if block_state.dtype is not None else components.unet.dtype

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor
        block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
        block_state.target_size = block_state.target_size or (block_state.height, block_state.width)

        block_state.add_time_ids = self._get_add_time_ids(
            components,
            block_state.original_size,
            block_state.crops_coords_top_left,
            block_state.target_size,
            dtype=dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
            block_state.negative_add_time_ids = self._get_add_time_ids(
                components,
                block_state.negative_original_size,
                block_state.negative_crops_coords_top_left,
                block_state.negative_target_size,
                dtype=dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            block_state.negative_add_time_ids = block_state.add_time_ids

        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)

        # Optionally get Guidance Scale Embedding for LCM
        block_state.timestep_cond = None
        if (
            hasattr(components, "unet")
            and components.unet is not None
            and components.unet.config.time_cond_proj_dim is not None
        ):
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
            block_state.timestep_cond = self.get_guidance_scale_embedding(
                block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
            ).to(device=block_state.device, dtype=block_state.latents.dtype)

        guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
        block_state.timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=dtype)

        self.set_block_state(state, block_state)
        return components, state
@@ -1459,29 +1466,29 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
        )

        # (1.2)
        # controlnet_conditioning_scale (align format)
        # conditioning_scale (align format)
        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            block_state.controlnet_conditioning_scale, float
        ):
            block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
            block_state.conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
                controlnet.nets
            )
        else:
            block_state.conditioning_scale = block_state.controlnet_conditioning_scale

        # (1.3)
        # global_pool_conditions
        block_state.global_pool_conditions = (
        # guess_mode
        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        # (1.4)
        # guess_mode
        block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
        block_state.guess_mode = block_state.guess_mode or global_pool_conditions

        # (1.5)
        # control_image
        # (1.4)
        # controlnet_cond
        if isinstance(controlnet, ControlNetModel):
            block_state.control_image = self.prepare_control_image(
            block_state.controlnet_cond = self.prepare_control_image(
                components,
                image=block_state.control_image,
                width=width,
@@ -1510,7 +1517,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):

                control_images.append(control_image)

            block_state.control_image = control_images
            block_state.controlnet_cond = control_images
        else:
            assert False

@@ -1524,9 +1531,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
            ]
            block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        block_state.controlnet_cond = block_state.control_image
        block_state.conditioning_scale = block_state.controlnet_conditioning_scale

        self.set_block_state(state, block_state)

        return components, state
@@ -1686,8 +1690,8 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
        ]

        # guess_mode
        block_state.global_pool_conditions = controlnet.config.global_pool_conditions
        block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
        global_pool_conditions = controlnet.config.global_pool_conditions
        block_state.guess_mode = block_state.guess_mode or global_pool_conditions

        # control_image
        if not isinstance(block_state.control_image, list):
@@ -1700,19 +1704,20 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
            raise ValueError("Expected len(control_image) == len(control_type)")

        # control_type
        block_state.num_control_type = controlnet.config.num_control_type
        block_state.control_type = [0 for _ in range(block_state.num_control_type)]
        num_control_type = controlnet.config.num_control_type
        block_state.control_type = [0 for _ in range(num_control_type)]
        for control_idx in block_state.control_mode:
            block_state.control_type[control_idx] = 1
        block_state.control_type = torch.Tensor(block_state.control_type)

        block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
        block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=dtype)
        repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
        block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)

        # prepare control_image
        # prepare controlnet_cond
        block_state.controlnet_cond = []
        for idx, _ in enumerate(block_state.control_image):
            block_state.control_image[idx] = self.prepare_control_image(
            control_image = self.prepare_control_image(
                components,
                image=block_state.control_image[idx],
                width=width,
@@ -1723,7 +1728,8 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
                dtype=dtype,
                crops_coords=block_state.crops_coords,
            )
            _, _, height, width = block_state.control_image[idx].shape
            _, _, height, width = control_image.shape
            block_state.controlnet_cond.append(control_image)

        # controlnet_keep
        block_state.controlnet_keep = []
@@ -1736,7 +1742,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
                )
            )
        block_state.control_type_idx = block_state.control_mode
        block_state.controlnet_cond = block_state.control_image
        block_state.conditioning_scale = block_state.controlnet_conditioning_scale

        self.set_block_state(state, block_state)

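Note: a standalone sketch of the control_type one-hot built in the ControlNet-union hunks above; num_control_type, control_mode and the repeat factor are illustrative values, not taken from a real controlnet config.

import torch

num_control_type = 6
control_mode = [0, 3]                                      # e.g. two active control types
control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
    control_type[control_idx] = 1
control_type = torch.Tensor(control_type).reshape(1, -1)   # shape (1, num_control_type)
control_type = control_type.repeat_interleave(2, dim=0)    # repeated up to the final batch size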
@@ -105,9 +105,9 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
        if not block_state.output_type == "latent":
            latents = block_state.latents
            # make sure the VAE is in float32 mode, as it overflows in float16
            block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
            needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast

            if block_state.needs_upcasting:
            if needs_upcasting:
                self.upcast_vae(components)
                latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != components.vae.dtype:
@@ -117,21 +117,21 @@ class StableDiffusionXLDecodeStep(PipelineBlock):

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            block_state.has_latents_mean = (
            has_latents_mean = (
                hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
            )
            block_state.has_latents_std = (
            has_latents_std = (
                hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
            )
            if block_state.has_latents_mean and block_state.has_latents_std:
                block_state.latents_mean = (
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                block_state.latents_std = (
                latents_std = (
                    torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = (
                    latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
                    latents * latents_std / components.vae.config.scaling_factor + latents_mean
                )
            else:
                latents = latents / components.vae.config.scaling_factor

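Note: an illustrative sketch of the latent unscaling/denormalization above; the scaling factor, mean and std here stand in for values that are normally read from vae.config.

import torch

scaling_factor = 0.13025                  # SDXL VAE default, used here only as an example value
latents = torch.randn(1, 4, 128, 128)
latents_mean = torch.zeros(1, 4, 1, 1)    # stand-in for vae.config.latents_mean
latents_std = torch.ones(1, 4, 1, 1)      # stand-in for vae.config.latents_std

# when mean/std are available, denormalize; otherwise just divide by the scaling factor
denormalized = latents * latents_std / scaling_factor + latents_mean
unscaled = latents / scaling_factor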
@@ -67,7 +67,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
        block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)

        return components, block_state

@@ -134,10 +134,10 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
        self.check_inputs(components, block_state)

        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
        block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
        if components.num_channels_unet == 9:
            block_state.scaled_latents = torch.cat(
                [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
            block_state.latent_model_input = torch.cat(
                [block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1
            )

        return components, block_state
@@ -232,7 +232,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
            # Predict the noise residual
            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
            guider_state_batch.noise_pred = components.unet(
                block_state.scaled_latents,
                block_state.latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=block_state.timestep_cond,
@@ -410,7 +410,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
                mid_block_res_sample = block_state.mid_block_res_sample_zeros
            else:
                down_block_res_samples, mid_block_res_sample = components.controlnet(
                    block_state.scaled_latents,
                    block_state.latent_model_input,
                    t,
                    encoder_hidden_states=guider_state_batch.prompt_embeds,
                    controlnet_cond=block_state.controlnet_cond,
@@ -430,7 +430,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
            # Predict the noise
            # store the noise_pred in guider_state_batch so we can apply guidance across all batches
            guider_state_batch.noise_pred = components.unet(
                block_state.scaled_latents,
                block_state.latent_model_input,
                t,
                encoder_hidden_states=guider_state_batch.prompt_embeds,
                timestep_cond=block_state.timestep_cond,

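Note: a small shape check for the 9-channel inpainting UNet input assembled in the loop hunks above; tensor sizes are illustrative only.

import torch

latent_model_input = torch.randn(2, 4, 64, 64)     # scaled latents
mask = torch.randn(2, 1, 64, 64)
masked_image_latents = torch.randn(2, 4, 64, 64)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
assert latent_model_input.shape[1] == 9             # matches num_channels_unet == 9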
@@ -390,7 +390,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        device = device or components._execution_device
        dtype = components.text_encoder_2.dtype

@@ -526,7 +525,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)

        device = components._execution_device
        dtype = components.text_encoder_2.dtype

        # Encode input prompt
        lora_scale = (
@@ -542,8 +540,8 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        ) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            prompt2=block_state.prompt_2,
            device = device,
            prompt_2=block_state.prompt_2,
            device=device,
            requires_unconditional_embeds=components.requires_unconditional_embeds,
            negative_prompt=block_state.negative_prompt,
            negative_prompt_2=block_state.negative_prompt_2,
@@ -604,11 +602,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
        device = components._execution_device
        dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.processed_image = components.image_processor.preprocess(block_state.image)
        image = components.image_processor.preprocess(block_state.image)

        # Encode image into latents
        block_state.image_latents = encode_vae_image(
            image=block_state.processed_image,
            image=image,
            vae=components.vae,
            generator=block_state.generator,
            dtype=dtype,
@@ -681,7 +679,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
            ),
            OutputParam(
                "mask_latents",
                "mask",
                type_hint=torch.Tensor,
                description="The mask to apply on the latents for the inpainting generation.",
            ),
@@ -715,15 +713,15 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
            width = components.default_width

        if block_state.padding_mask_crop is not None:
            crops_coords = components.mask_processor.get_crop_region(
            block_state.crops_coords = components.mask_processor.get_crop_region(
                mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
            )
            resize_mode = "fill"
        else:
            crops_coords = None
            block_state.crops_coords = None
            resize_mode = "default"

        processed_image = components.image_processor.preprocess(
        image = components.image_processor.preprocess(
            block_state.image,
            height=height,
            width=width,
@@ -731,9 +729,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
            resize_mode=resize_mode,
        )

        processed_image = processed_image.to(dtype=torch.float32)
        image = image.to(dtype=torch.float32)

        processed_mask_image = components.mask_processor.preprocess(
        mask = components.mask_processor.preprocess(
            block_state.mask_image,
            height=height,
            width=width,
@@ -741,11 +739,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
            crops_coords=crops_coords,
        )

        masked_image = processed_image * (block_state.mask_latents < 0.5)
        masked_image = image * (block_state.mask_latents < 0.5)

        # Prepare image latent variables
        block_state.image_latents = encode_vae_image(
            image=processed_image,
            image=image,
            vae=components.vae,
            generator=block_state.generator,
            dtype=dtype,
@@ -763,11 +761,11 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):

        # resize mask to match the image latents
        _, _, height_latents, width_latents = block_state.image_latents.shape
        block_state.mask_latents = torch.nn.functional.interpolate(
            processed_mask_image,
        block_state.mask = torch.nn.functional.interpolate(
            mask,
            size=(height_latents, width_latents),
        )
        block_state.mask_latents = block_state.mask_latents.to(dtype=dtype, device=device)
        block_state.mask = block_state.mask.to(dtype=dtype, device=device)

        self.set_block_state(state, block_state)

@@ -26,6 +26,7 @@ from .before_denoise import (
    StableDiffusionXLPrepareAdditionalConditioningStep,
    StableDiffusionXLPrepareLatentsStep,
    StableDiffusionXLSetTimestepsStep,
    StableDiffusionXLLCMStep,
)
from .decoders import (
    StableDiffusionXLDecodeStep,
@@ -79,6 +80,16 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
        return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"


class StableDiffusionXLAutoLCMStep(AutoPipelineBlocks):
    block_classes = [StableDiffusionXLLCMStep]
    block_names = ["lcm"]
    block_trigger_inputs = ["embedded_guidance_scale"]

    @property
    def description(self):
        return "Run LCM step if `latents` is provided. This step should be placed before the 'input' step.\n"


# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
@@ -262,6 +273,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
        StableDiffusionXLAutoIPAdapterStep,
        StableDiffusionXLAutoVaeEncoderStep,
        StableDiffusionXLAutoBeforeDenoiseStep,
        StableDiffusionXLAutoLCMStep,
        StableDiffusionXLAutoControlNetInputStep,
        StableDiffusionXLAutoDenoiseStep,
        StableDiffusionXLAutoDecodeStep,
@@ -271,6 +283,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
        "ip_adapter",
        "image_encoder",
        "before_denoise",
        "lcm",
        "controlnet_input",
        "denoise",
        "decoder",
@@ -286,6 +299,7 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
            + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
            + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
            + "- to run the latent consistency models workflow, you need to provide `embedded_guidance_scale`"
        )

@@ -357,6 +371,13 @@ IP_ADAPTER_BLOCKS = InsertableDict(
    ]
)

LCM_BLOCKS = InsertableDict(
    [
        ("lcm", StableDiffusionXLAutoLCMStep),
    ]
)

AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", StableDiffusionXLTextEncoderStep),
@@ -376,5 +397,6 @@ ALL_BLOCKS = {
    "inpaint": INPAINT_BLOCKS,
    "controlnet": CONTROLNET_BLOCKS,
    "ip_adapter": IP_ADAPTER_BLOCKS,
    "lcm": LCM_BLOCKS,
    "auto": AUTO_BLOCKS,
}

@@ -95,7 +95,10 @@ class StableDiffusionXLModularPipeline(
        # by default, always prepare unconditional embeddings
        requires_unconditional_embeds = True

        if hasattr(self, "guider") and self.guider is not None:
        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is None:
            requires_unconditional_embeds = False

        elif hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider.num_conditions > 1

        return requires_unconditional_embeds