
initial draft update

This commit is contained in:
yiyixuxu
2025-08-05 03:01:36 +02:00
parent 7ea065c507
commit 8946974ccc
2 changed files with 194 additions and 403 deletions


@@ -117,84 +117,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
def prepare_latents_img2img(
vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if vae.config.force_upcast:
image = image.float()
vae.to(dtype=torch.float32)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(vae.encode(image), generator=generator)
if vae.config.force_upcast:
vae.to(dtype)
init_latents = init_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=device, dtype=dtype)
latents_std = latents_std.to(device=device, dtype=dtype)
init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
else:
init_latents = vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
if add_noise:
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
class StableDiffusionXLInputStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@@ -419,8 +341,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
InputParam("denoising_end"),
InputParam("strength", default=0.3),
InputParam("denoising_start"),
# YiYi TODO: do we need num_images_per_prompt here?
InputParam("num_images_per_prompt", default=1),
]
@property
@@ -495,7 +415,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
@@ -512,14 +432,12 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
components,
block_state.num_inference_steps,
block_state.strength,
block_state.device,
device,
denoising_start=block_state.denoising_start
if denoising_value_valid(block_state.denoising_start)
else None,
)
block_state.latent_timestep = block_state.timesteps[:1].repeat(
block_state.batch_size * block_state.num_images_per_prompt
)
block_state.latent_timestep = block_state.timesteps[:1]
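# note: a single, unbatched timestep is kept here; it is repeated to the final batch size in the prepare-latents steps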
if (
block_state.denoising_end is not None
@@ -527,14 +445,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
and block_state.denoising_end > 0
and block_state.denoising_end < 1
):
block_state.discrete_timestep_cutoff = int(
discrete_timestep_cutoff = int(
round(
components.scheduler.config.num_train_timesteps
- (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
)
)
block_state.num_inference_steps = len(
list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
)
block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
@@ -596,14 +514,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
and block_state.denoising_end > 0
and block_state.denoising_end < 1
):
block_state.discrete_timestep_cutoff = int(
discrete_timestep_cutoff = int(
round(
components.scheduler.config.num_train_timesteps
- (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
)
)
block_state.num_inference_steps = len(
list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
)
block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
@@ -627,7 +545,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
InputParam(
@@ -654,7 +571,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
),
@@ -691,209 +607,99 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
),
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components
# YiYi TODO: update the _encode_vae_image so that we can use # Copied from
@staticmethod
def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
def prepare_latents_inpaint(
def prepare_latents(
self,
components,
batch_size,
num_channels_latents,
height,
width,
image_latents,
dtype,
device,
generator,
latents=None,
image=None,
timestep=None,
is_strength_max=True,
add_noise=True,
return_noise=False,
return_image_latents=False,
):
shape = (
batch_size,
num_channels_latents,
int(height) // components.vae_scale_factor,
int(width) // components.vae_scale_factor,
)
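# the batch size is now derived from the pre-encoded image latents rather than passed in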
batch_size = image_latents.shape[0]
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if (image is None or timestep is None) and not is_strength_max:
raise ValueError(
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
"However, either the image or the noise timestep has not been provided."
)
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None and add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if add_noise:
noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
# if strength is 1, then initialise the latents to noise; else initialise to image + noise
latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
elif add_noise:
noise = latents.to(device)
latents = noise * components.scheduler.init_noise_sigma
else:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
outputs = (latents,)
return latents, noise
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
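# the prepare-latents inputs are now expected unbatched: image_latents, mask and masked_image_latents with batch size 1, and a single latent_timestep; they are expanded to the final batch size in __call__ below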
def check_inputs(self, image_latents, mask, masked_image_latents, latent_timestep):
if image_latents.shape[0] != 1:
raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")
if mask.shape[0] != 1:
raise ValueError(f"mask should have batch size 1, but got {mask.shape[0]}")
if masked_image_latents is not None and masked_image_latents.shape[0] != 1:
raise ValueError(f"masked_image_latents should have batch size 1, but got {masked_image_latents.shape[0]}")
if latent_timestep is not None and latent_timestep.numel() != 1:
raise ValueError(f"latent_timestep should be a scalar or a 1-element tensor, but got shape {latent_timestep.shape}")
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
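# default to the dtype of the incoming image latents instead of the VAE dtype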
dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
device = components._execution_device
block_state.is_strength_max = block_state.strength == 1.0
# for a unet that is not inpainting-specific, we do not need masked_image_latents
if hasattr(components, "unet") and components.unet is not None:
if components.unet.config.in_channels == 4:
block_state.masked_image_latents = None
block_state.add_noise = True if block_state.denoising_start is None else False
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
components,
block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
image=block_state.image_latents,
timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
)
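# expand the unbatched image latents, mask, masked-image latents and latent_timestep to the final batch size before building the initial latents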
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
_, _, height_latents, width_latents = block_state.image_latents.shape
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image_latents,
block_state.batch_size * block_state.num_images_per_prompt,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.mask = torch.nn.functional.interpolate(
block_state.mask, size=(height_latents, width_latents)
)
block_state.mask = block_state.mask.to(device=device, dtype=dtype)
block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)
if block_state.masked_image_latents is not None:
block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)
if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
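# strength == 1 starts from pure noise; when denoising_start is set, the latents are assumed to be already noised, so no extra noise is added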
is_strength_max = block_state.strength == 1.0
add_noise = True if block_state.denoising_start is None else False
block_state.latents, block_state.noise = self.prepare_latents(
components=components,
image_latents=block_state.image_latents,
dtype=dtype,
device=device,
generator=block_state.generator,
timestep=block_state.latent_timestep,
is_strength_max=is_strength_max,
add_noise=add_noise,
)
self.set_block_state(state, block_state)
@@ -906,7 +712,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec("scheduler", EulerDiscreteScheduler),
]
@@ -917,7 +722,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
]
@@ -928,7 +732,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
InputParam("generator"),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
),
@@ -944,7 +747,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
@@ -954,27 +757,52 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]
def check_inputs(self, image_latents):
if image_latents.shape[0] != 1:
raise ValueError(f"image_latents should have have batch size 1, but got {image_latents.shape[0]}")
def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {image_latents.shape[0]}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
latents = scheduler.add_noise(image_latents, noise, timestep)
return latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
block_state.add_noise = True if block_state.denoising_start is None else False
if block_state.latents is None:
block_state.latents = prepare_latents_img2img(
components.vae,
components.scheduler,
dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
device = components._execution_device
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)
if block_state.latent_timestep is not None:
block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)
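# if denoising_start is set, the image latents are used as-is; otherwise noise them to the initial timestep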
add_noise = True if block_state.denoising_start is None else False
if add_noise:
block_state.latents = prepare_latents(
block_state.image_latents,
block_state.latent_timestep,
block_state.batch_size,
block_state.num_images_per_prompt,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.add_noise,
components.scheduler,
timestep=block_state.latent_timestep,
dtype=dtype,
device=device,
generator=block_state.generator,
)
else:
block_state.latents = block_state.image_latents
self.set_block_state(state, block_state)
@@ -988,7 +816,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("vae", AutoencoderKL),
]
@property
@@ -1026,15 +853,15 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
]
@staticmethod
def check_inputs(components, block_state):
def check_inputs(components, height, width):
if (
block_state.height is not None
and block_state.height % components.vae_scale_factor != 0
or block_state.width is not None
and block_state.width % components.vae_scale_factor != 0
height is not None
and height % components.vae_scale_factor != 0
or width is not None
and width % components.vae_scale_factor != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {height} and {width}."
)
@staticmethod
@@ -1065,26 +892,27 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.dtype is None:
block_state.dtype = components.vae.dtype
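# the VAE dtype is no longer used here; fall back to the UNet dtype (or float32) when no dtype is provided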
dtype = block_state.dtype
if dtype is None:
dtype = components.unet.dtype if hasattr(components, "unet") else torch.float32
block_state.device = components._execution_device
device = components._execution_device
self.check_inputs(components, block_state)
self.check_inputs(components, block_state.height, block_state.width)
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
height = block_state.height or components.default_sample_size * components.vae_scale_factor
width = block_state.width or components.default_sample_size * components.vae_scale_factor
block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
block_state.num_channels_latents = components.num_channels_latents
block_state.latents = self.prepare_latents(
components,
block_state.batch_size * block_state.num_images_per_prompt,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
comp=components,
batch_size=final_batch_size,
num_channels_latents=components.num_channels_latents,
height=height,
width=width,
dtype=dtype,
device=device,
generator=block_state.generator,
latents=block_state.latents,
)
self.set_block_state(state, block_state)
@@ -1103,15 +931,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
return [ComponentSpec("unet", UNet2DConditionModel),]
@property
def description(self) -> str:
@@ -1129,6 +949,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1),
InputParam("aesthetic_score", default=6.0),
InputParam("negative_aesthetic_score", default=2.0),
InputParam("embedded_guidance_scale", default=7.5),
]
@property
@@ -1259,18 +1080,20 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
dtype = block_state.pooled_prompt_embeds.dtype
block_state.vae_scale_factor = components.vae_scale_factor
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * block_state.vae_scale_factor
block_state.width = block_state.width * block_state.vae_scale_factor
# define original_size/negative_original_size/target_size/negative_target_size
# - they all default to None
_, _, height_latents, width_latents = block_state.latents.shape
height = height_latents * components.vae_scale_factor
width = width_latents * components.vae_scale_factor
block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
block_state.original_size = block_state.original_size or (height, width)
block_state.target_size = block_state.target_size or (height, width)
if block_state.negative_original_size is None:
block_state.negative_original_size = block_state.original_size
@@ -1287,15 +1110,11 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
block_state.negative_original_size,
block_state.negative_crops_coords_top_left,
block_state.negative_target_size,
dtype=block_state.pooled_prompt_embeds.dtype,
text_encoder_projection_dim=block_state.text_encoder_projection_dim,
dtype=dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
block_state.add_time_ids = block_state.add_time_ids.repeat(
block_state.batch_size * block_state.num_images_per_prompt, 1
).to(device=block_state.device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
block_state.batch_size * block_state.num_images_per_prompt, 1
).to(device=block_state.device)
block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
# Optionally get Guidance Scale Embedding for LCM
block_state.timestep_cond = None
@@ -1305,12 +1124,10 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
and components.unet.config.time_cond_proj_dim is not None
):
# TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
block_state.batch_size * block_state.num_images_per_prompt
)
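# use the embedded_guidance_scale input (default 7.5) for the LCM guidance embedding instead of reading the scale from the guider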
block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
block_state.timestep_cond = self.get_guidance_scale_embedding(
block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
).to(device=block_state.device, dtype=block_state.latents.dtype)
).to(device=device, dtype=dtype)
self.set_block_state(state, block_state)
return components, state
@@ -1325,15 +1142,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
return [ComponentSpec("unet", UNet2DConditionModel),]
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -1345,6 +1154,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1),
InputParam("embedded_guidance_scale", default=7.5),
]
@property
@@ -1442,24 +1252,26 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
dtype = block_state.pooled_prompt_embeds.dtype
text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
_, _, height_latents, width_latents = block_state.latents.shape
height = height_latents * components.vae_scale_factor
width = width_latents * components.vae_scale_factor
block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
block_state.add_time_ids = self._get_add_time_ids(
components,
block_state.original_size,
block_state.crops_coords_top_left,
block_state.target_size,
block_state.pooled_prompt_embeds.dtype,
text_encoder_projection_dim=block_state.text_encoder_projection_dim,
dtype=dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
block_state.negative_add_time_ids = self._get_add_time_ids(
@@ -1467,18 +1279,14 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
block_state.negative_original_size,
block_state.negative_crops_coords_top_left,
block_state.negative_target_size,
block_state.pooled_prompt_embeds.dtype,
text_encoder_projection_dim=block_state.text_encoder_projection_dim,
dtype=dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
block_state.negative_add_time_ids = block_state.add_time_ids
block_state.add_time_ids = block_state.add_time_ids.repeat(
block_state.batch_size * block_state.num_images_per_prompt, 1
).to(device=block_state.device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
block_state.batch_size * block_state.num_images_per_prompt, 1
).to(device=block_state.device)
block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)
# Optionally get Guidance Scale Embedding for LCM
block_state.timestep_cond = None
@@ -1488,9 +1296,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
and components.unet.config.time_cond_proj_dim is not None
):
# TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
block_state.batch_size * block_state.num_images_per_prompt
)
block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
block_state.timestep_cond = self.get_guidance_scale_embedding(
block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
).to(device=device, dtype=dtype)
@@ -1613,14 +1419,18 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# (1) prepare controlnet inputs
block_state.device = components._execution_device
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
controlnet = unwrap_module(components.controlnet)
device = components._execution_device
dtype = components.controlnet.dtype
_, _, height_latents, width_latents = block_state.latents.shape
height = height_latents * components.vae_scale_factor
width = width_latents * components.vae_scale_factor
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
# (1) prepare controlnet inputs
# (1.1)
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1670,12 +1480,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
block_state.control_image = self.prepare_control_image(
components,
image=block_state.control_image,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
width=width,
height=height,
batch_size=final_batch_size,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
device=device,
dtype=dtype,
crops_coords=block_state.crops_coords,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1685,12 +1495,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
control_image = self.prepare_control_image(
components,
image=control_image_,
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
width=width,
height=height,
batch_size=final_batch_size,
num_images_per_prompt=block_state.num_images_per_prompt,
device=block_state.device,
dtype=controlnet.dtype,
device=device,
dtype=dtype,
crops_coords=block_state.crops_coords,
)
@@ -1852,9 +1662,10 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
device = components._execution_device
dtype = block_state.dtype or components.controlnet.dtype
block_state.height, block_state.width = block_state.latents.shape[-2:]
block_state.height = block_state.height * components.vae_scale_factor
block_state.width = block_state.width * components.vae_scale_factor
_, _, height_latents, width_latents = block_state.latents.shape
height = height_latents * components.vae_scale_factor
width = width_latents * components.vae_scale_factor
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1900,15 +1711,15 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
block_state.control_image[idx] = self.prepare_control_image(
components,
image=block_state.control_image[idx],
width=block_state.width,
height=block_state.height,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
width=width,
height=height,
batch_size=final_batch_size,
num_images_per_prompt=block_state.num_images_per_prompt,
device=device,
dtype=dtype,
crops_coords=block_state.crops_coords,
)
block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
_, _, height, width = block_state.control_image[idx].shape
# controlnet_keep
block_state.controlnet_keep = []


@@ -813,29 +813,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
return mask, masked_image_latents