initial draft update
@@ -117,84 +117,6 @@ def retrieve_latents(
        raise AttributeError("Could not access latents of provided encoder_output")


def prepare_latents_img2img(
    vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image

    else:
        latents_mean = latents_std = None
        if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
            latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
            latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
        # make sure the VAE is in float32 mode, as it overflows in float16
        if vae.config.force_upcast:
            image = image.float()
            vae.to(dtype=torch.float32)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                )

            init_latents = [
                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = retrieve_latents(vae.encode(image), generator=generator)

        if vae.config.force_upcast:
            vae.to(dtype)

        init_latents = init_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=device, dtype=dtype)
            latents_std = latents_std.to(device=device, dtype=dtype)
            init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
        else:
            init_latents = vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    if add_noise:
        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # get latents
        init_latents = scheduler.add_noise(init_latents, noise, timestep)

    latents = init_latents

    return latents


class StableDiffusionXLInputStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

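For reference, a minimal sketch (not part of this diff, using an assumed 30-step schedule) of how the `strength` input is conventionally mapped to the starting timestep that `prepare_latents_img2img` consumes:

# Illustrative only: strength -> starting-timestep arithmetic.
num_inference_steps = 30
strength = 0.3  # keep most of the original image structure

# keep the last `strength` fraction of the schedule; the rest is skipped
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(t_start)  # 21 -> timesteps[t_start] is the noise level handed to scheduler.add_noise
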
@@ -419,8 +341,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
            InputParam("denoising_end"),
            InputParam("strength", default=0.3),
            InputParam("denoising_start"),
            # YiYi TODO: do we need num_images_per_prompt here?
            InputParam("num_images_per_prompt", default=1),
        ]

    @property
@@ -495,7 +415,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.device = components._execution_device
        device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
@@ -512,14 +432,12 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
            components,
            block_state.num_inference_steps,
            block_state.strength,
            block_state.device,
            device,
            denoising_start=block_state.denoising_start
            if denoising_value_valid(block_state.denoising_start)
            else None,
        )
        block_state.latent_timestep = block_state.timesteps[:1].repeat(
            block_state.batch_size * block_state.num_images_per_prompt
        )
        block_state.latent_timestep = block_state.timesteps[:1]

        if (
            block_state.denoising_end is not None
@@ -527,14 +445,14 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
            and block_state.denoising_end > 0
            and block_state.denoising_end < 1
        ):
            block_state.discrete_timestep_cutoff = int(
            discrete_timestep_cutoff = int(
                round(
                    components.scheduler.config.num_train_timesteps
                    - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
                )
            )
            block_state.num_inference_steps = len(
                list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
            )
            block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]

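For reference, a worked example (not part of this diff) of the cutoff arithmetic above, assuming the usual 1000 training timesteps:

# Illustrative only: the denoising_end cutoff in numbers.
num_train_timesteps = 1000
denoising_end = 0.8

discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
print(discrete_timestep_cutoff)  # 200 -> only timesteps >= 200 are kept, i.e. the first
# 80% of the trajectory runs here and the remainder can be handed to a refiner.
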
@@ -596,14 +514,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
            and block_state.denoising_end > 0
            and block_state.denoising_end < 1
        ):
            block_state.discrete_timestep_cutoff = int(
            discrete_timestep_cutoff = int(
                round(
                    components.scheduler.config.num_train_timesteps
                    - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
                )
            )
            block_state.num_inference_steps = len(
                list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, block_state.timesteps))
            )
            block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]

@@ -627,7 +545,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("latents"),
            InputParam("num_images_per_prompt", default=1),
            InputParam("denoising_start"),
            InputParam(
@@ -654,7 +571,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
            ),
            InputParam(
                "latent_timestep",
                required=True,
                type_hint=torch.Tensor,
                description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
            ),
@@ -691,209 +607,99 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
            ),
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components
    # YiYi TODO: update the _encode_vae_image so that we can use #Copied from
    @staticmethod
    def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

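For reference, a small sketch (not part of this diff) of the normalization `_encode_vae_image` applies, with toy tensors and an assumed SDXL-style scaling factor standing in for the real VAE config:

import torch

scaling_factor = 0.13025                  # assumption; typical SDXL VAE value
latents_mean = torch.zeros(1, 4, 1, 1)
latents_std = torch.ones(1, 4, 1, 1)
image_latents = torch.randn(1, 4, 64, 64)

# with mean/std present: shift, then scale into the UNet's expected range
normalized = (image_latents - latents_mean) * scaling_factor / latents_std
# without them the code falls back to plain scaling: scaling_factor * image_latents
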
    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
    def prepare_latents_inpaint(
    def prepare_latents(
        self,
        components,
        batch_size,
        num_channels_latents,
        height,
        width,
        image_latents,
        dtype,
        device,
        generator,
        latents=None,
        image=None,
        timestep=None,
        is_strength_max=True,
        add_noise=True,
        return_noise=False,
        return_image_latents=False,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // components.vae_scale_factor,
            int(width) // components.vae_scale_factor,
        )

        batch_size = image_latents.shape[0]

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if (image is None or timestep is None) and not is_strength_max:
            raise ValueError(
                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
                "However, either the image or the noise timestep has not been provided."
            )

        if image.shape[1] == 4:
            image_latents = image.to(device=device, dtype=dtype)
            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
        elif return_image_latents or (latents is None and not is_strength_max):
            image = image.to(device=device, dtype=dtype)
            image_latents = self._encode_vae_image(components, image=image, generator=generator)
            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

        if latents is None and add_noise:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if add_noise:
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
        elif add_noise:
            noise = latents.to(device)
            latents = noise * components.scheduler.init_noise_sigma

        else:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
            latents = image_latents.to(device)

        outputs = (latents,)
        return latents, noise

        if return_noise:
            outputs += (noise,)

        if return_image_latents:
            outputs += (image_latents,)

        return outputs

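For reference, a minimal sketch (not part of this diff) of the rule the noise branch above reduces to; `scheduler`, `image_latents`, `noise` and `timestep` are assumed stand-ins for the real objects:

def initial_latents(image_latents, noise, timestep, scheduler, is_strength_max):
    if is_strength_max:
        # strength == 1.0: ignore the image entirely and start from pure noise,
        # scaled by the scheduler's initial sigma
        return noise * scheduler.init_noise_sigma
    # strength < 1.0: start from the image latents with noise added at `timestep`
    return scheduler.add_noise(image_latents, noise, timestep)
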
    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
    # do not accept do_classifier_free_guidance
    def prepare_mask_latents(
        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

        if masked_image is not None and masked_image.shape[1] == 4:
            masked_image_latents = masked_image
        else:
            masked_image_latents = None

        if masked_image is not None:
            if masked_image_latents is None:
                masked_image = masked_image.to(device=device, dtype=dtype)
                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

            if masked_image_latents.shape[0] < batch_size:
                if not batch_size % masked_image_latents.shape[0] == 0:
                    raise ValueError(
                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                        " Make sure the number of images that you pass is divisible by the total requested batch size."
                    )
                masked_image_latents = masked_image_latents.repeat(
                    batch_size // masked_image_latents.shape[0], 1, 1, 1
                )

            # aligning device to prevent device errors when concating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        return mask, masked_image_latents

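For reference, a small runnable sketch (not part of this diff) of what the interpolate call in `prepare_mask_latents` does: bring a pixel-space mask down to latent resolution so it can be concatenated with the latents; the resolution and scale factor below are assumptions:

import torch

vae_scale_factor = 8                        # SDXL default; assumption here
mask = torch.rand(1, 1, 1024, 1024)         # pixel-space mask in [0, 1]
mask_latent = torch.nn.functional.interpolate(
    mask, size=(1024 // vae_scale_factor, 1024 // vae_scale_factor)
)
print(mask_latent.shape)  # torch.Size([1, 1, 128, 128])
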
    def check_inputs(self, image_latents, mask, masked_image_latents, latent_timestep):

        if image_latents.shape[0] != 1:
            raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")
        if mask.shape[0] != 1:
            raise ValueError(f"mask should have batch size 1, but got {mask.shape[0]}")
        if masked_image_latents is not None and masked_image_latents.shape[0] != 1:
            raise ValueError(f"masked_image_latents should have batch size 1, but got {masked_image_latents.shape[0]}")

        if latent_timestep is not None and len(latent_timestep.shape) > 0:
            raise ValueError(f"latent_timestep should be a scalar, but got {latent_timestep.shape}")


    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        self.check_inputs(block_state.image_latents, block_state.mask, block_state.masked_image_latents, block_state.latent_timestep)

        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device
        dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
        device = components._execution_device

        block_state.is_strength_max = block_state.strength == 1.0

        # for non-inpainting specific unet, we do not need masked_image_latents
        if hasattr(components, "unet") and components.unet is not None:
            if components.unet.config.in_channels == 4:
                block_state.masked_image_latents = None

        block_state.add_noise = True if block_state.denoising_start is None else False

        block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
        block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor

        block_state.latents, block_state.noise = self.prepare_latents_inpaint(
            components,
            block_state.batch_size * block_state.num_images_per_prompt,
            components.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
            image=block_state.image_latents,
            timestep=block_state.latent_timestep,
            is_strength_max=block_state.is_strength_max,
            add_noise=block_state.add_noise,
            return_noise=True,
            return_image_latents=False,
        )
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
        _, _, height_latents, width_latents = block_state.image_latents.shape

        block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)

        # 7. Prepare mask latent variables
        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
            components,
            block_state.mask,
            block_state.masked_image_latents,
            block_state.batch_size * block_state.num_images_per_prompt,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
        block_state.mask = torch.nn.functional.interpolate(
            block_state.mask, size=(height_latents, width_latents)
        )
        block_state.mask = block_state.mask.to(device=device, dtype=dtype)
        block_state.mask = block_state.mask.repeat(final_batch_size, 1, 1, 1)

        if block_state.masked_image_latents is not None:
            block_state.masked_image_latents = block_state.masked_image_latents.to(device=device, dtype=dtype)
            block_state.masked_image_latents = block_state.masked_image_latents.repeat(final_batch_size, 1, 1, 1)

        if block_state.latent_timestep is not None:
            block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
            block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)

        is_strength_max = block_state.strength == 1.0
        add_noise = True if block_state.denoising_start is None else False

        block_state.latents, block_state.noise = self.prepare_latents(
            components=components,
            image_latents=block_state.image_latents,
            dtype=dtype,
            device=device,
            generator=block_state.generator,
            timestep=block_state.latent_timestep,
            is_strength_max=is_strength_max,
            add_noise=add_noise,
        )


        self.set_block_state(state, block_state)

@@ -906,7 +712,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
        ]

@@ -917,7 +722,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("latents"),
            InputParam("num_images_per_prompt", default=1),
            InputParam("denoising_start"),
        ]
@@ -928,7 +732,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
            InputParam("generator"),
            InputParam(
                "latent_timestep",
                required=True,
                type_hint=torch.Tensor,
                description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
            ),
@@ -944,7 +747,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
@@ -954,27 +757,52 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            )
        ]

    def check_inputs(self, image_latents):
        if image_latents.shape[0] != 1:
            raise ValueError(f"image_latents should have batch size 1, but got {image_latents.shape[0]}")

    def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
        if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {image_latents.shape[0]}. Make sure the batch size matches the length of the generators."
            )

        noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
        latents = scheduler.add_noise(image_latents, noise, timestep)

        return latents

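For reference, a sketch (not part of this diff) of what `scheduler.add_noise` computes for a DDPM-style schedule, which is the whole job of the `prepare_latents` helper above once the noise tensor is drawn; Euler-type schedulers instead add sigma-scaled noise, and the alpha value below is an assumed stand-in:

import torch

alpha_cumprod_t = torch.tensor(0.85)        # assumed cumulative alpha at `timestep`
image_latents = torch.randn(1, 4, 128, 128)
noise = torch.randn_like(image_latents)

noisy_latents = alpha_cumprod_t.sqrt() * image_latents + (1 - alpha_cumprod_t).sqrt() * noise
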
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device
        block_state.add_noise = True if block_state.denoising_start is None else False
        if block_state.latents is None:
            block_state.latents = prepare_latents_img2img(
                components.vae,
                components.scheduler,
        dtype = block_state.dtype if block_state.dtype is not None else block_state.image_latents.dtype
        device = components._execution_device

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        block_state.image_latents = block_state.image_latents.to(device=device, dtype=dtype)
        block_state.image_latents = block_state.image_latents.repeat(final_batch_size, 1, 1, 1)

        if block_state.latent_timestep is not None:
            block_state.latent_timestep = block_state.latent_timestep.repeat(final_batch_size)
            block_state.latent_timestep = block_state.latent_timestep.to(device=device, dtype=dtype)

        add_noise = True if block_state.denoising_start is None else False

        if add_noise:
            block_state.latents = prepare_latents(
                block_state.image_latents,
                block_state.latent_timestep,
                block_state.batch_size,
                block_state.num_images_per_prompt,
                block_state.dtype,
                block_state.device,
                block_state.generator,
                block_state.add_noise,
                components.scheduler,
                timestep=block_state.latent_timestep,
                dtype=dtype,
                device=device,
                generator=block_state.generator,
            )
        else:
            block_state.latents = block_state.image_latents

        self.set_block_state(state, block_state)

@@ -988,7 +816,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("vae", AutoencoderKL),
        ]

    @property
@@ -1026,15 +853,15 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
        ]

    @staticmethod
    def check_inputs(components, block_state):
    def check_inputs(components, height, width):
        if (
            block_state.height is not None
            and block_state.height % components.vae_scale_factor != 0
            or block_state.width is not None
            and block_state.width % components.vae_scale_factor != 0
            height is not None
            and height % components.vae_scale_factor != 0
            or width is not None
            and width % components.vae_scale_factor != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
                f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {height} and {width}."
            )

    @staticmethod
@@ -1065,26 +892,27 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        if block_state.dtype is None:
            block_state.dtype = components.vae.dtype
        dtype = block_state.dtype
        if dtype is None:
            dtype = components.unet.dtype if hasattr(components, "unet") else torch.float32

        block_state.device = components._execution_device
        device = components._execution_device

        self.check_inputs(components, block_state)
        self.check_inputs(components, block_state.height, block_state.width)
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
        height = block_state.height or components.default_sample_size * components.vae_scale_factor
        width = block_state.width or components.default_sample_size * components.vae_scale_factor

        block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
        block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
        block_state.num_channels_latents = components.num_channels_latents
        block_state.latents = self.prepare_latents(
            components,
            block_state.batch_size * block_state.num_images_per_prompt,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
            comp=components,
            batch_size=final_batch_size,
            num_channels_latents=components.num_channels_latents,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=block_state.generator,
            latents=block_state.latents,
        )

        self.set_block_state(state, block_state)
@@ -1103,15 +931,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]
        return [ComponentSpec("unet", UNet2DConditionModel),]

    @property
    def description(self) -> str:
@@ -1129,6 +949,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
            InputParam("num_images_per_prompt", default=1),
            InputParam("aesthetic_score", default=6.0),
            InputParam("negative_aesthetic_score", default=2.0),
            InputParam("embedded_guidance_scale", default=7.5),
        ]

    @property
@@ -1259,18 +1080,20 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        device = components._execution_device
        dtype = block_state.pooled_prompt_embeds.dtype

        block_state.vae_scale_factor = components.vae_scale_factor
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])

        block_state.height, block_state.width = block_state.latents.shape[-2:]
        block_state.height = block_state.height * block_state.vae_scale_factor
        block_state.width = block_state.width * block_state.vae_scale_factor
        # define original_size/negative_original_size/target_size/negative_target_size
        # - they are all defaulted to None
        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor

        block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
        block_state.target_size = block_state.target_size or (block_state.height, block_state.width)

        block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
        block_state.original_size = block_state.original_size or (height, width)
        block_state.target_size = block_state.target_size or (height, width)

        if block_state.negative_original_size is None:
            block_state.negative_original_size = block_state.original_size
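For reference, a small sketch (not part of this diff) of the micro-conditioning tensor that `_get_add_time_ids` builds from the sizes computed above; the sizes and dtype are assumptions, and when the UNet was trained as a refiner the last entries are the aesthetic scores instead of the target size:

import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16)
print(add_time_ids.shape)  # torch.Size([1, 6]); repeated to final_batch_size below
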
@@ -1287,15 +1110,11 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
            block_state.negative_original_size,
            block_state.negative_crops_coords_top_left,
            block_state.negative_target_size,
            dtype=block_state.pooled_prompt_embeds.dtype,
            text_encoder_projection_dim=block_state.text_encoder_projection_dim,
            dtype=dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        block_state.add_time_ids = block_state.add_time_ids.repeat(
            block_state.batch_size * block_state.num_images_per_prompt, 1
        ).to(device=block_state.device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
            block_state.batch_size * block_state.num_images_per_prompt, 1
        ).to(device=block_state.device)
        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)

        # Optionally get Guidance Scale Embedding for LCM
        block_state.timestep_cond = None
@@ -1305,12 +1124,10 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
            and components.unet.config.time_cond_proj_dim is not None
        ):
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
                block_state.batch_size * block_state.num_images_per_prompt
            )
            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
            block_state.timestep_cond = self.get_guidance_scale_embedding(
                block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
            ).to(device=block_state.device, dtype=block_state.latents.dtype)
            ).to(device=device, dtype=dtype)

        self.set_block_state(state, block_state)
        return components, state
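For reference, a minimal sketch (not part of this diff) of the sinusoidal guidance-scale embedding that `get_guidance_scale_embedding` produces for LCM-style UNets with `time_cond_proj_dim` set, assuming the standard formulation; the embedding dimension below is an assumption:

import torch

def guidance_scale_embedding(w, embedding_dim=256, dtype=torch.float32):
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    return emb  # shape: (batch, embedding_dim)

timestep_cond = guidance_scale_embedding(torch.tensor([7.5 - 1.0]), embedding_dim=256)
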
@@ -1325,15 +1142,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]
        return [ComponentSpec("unet", UNet2DConditionModel),]

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
@@ -1345,6 +1154,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
            InputParam("crops_coords_top_left", default=(0, 0)),
            InputParam("negative_crops_coords_top_left", default=(0, 0)),
            InputParam("num_images_per_prompt", default=1),
            InputParam("embedded_guidance_scale", default=7.5),
        ]

    @property
@@ -1442,24 +1252,26 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        device = components._execution_device
        dtype = block_state.pooled_prompt_embeds.dtype
        text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])

        block_state.height, block_state.width = block_state.latents.shape[-2:]
        block_state.height = block_state.height * components.vae_scale_factor
        block_state.width = block_state.width * components.vae_scale_factor
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor
        block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
        block_state.target_size = block_state.target_size or (block_state.height, block_state.width)

        block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])

        block_state.add_time_ids = self._get_add_time_ids(
            components,
            block_state.original_size,
            block_state.crops_coords_top_left,
            block_state.target_size,
            block_state.pooled_prompt_embeds.dtype,
            text_encoder_projection_dim=block_state.text_encoder_projection_dim,
            dtype=dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
            block_state.negative_add_time_ids = self._get_add_time_ids(
@@ -1467,18 +1279,14 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
                block_state.negative_original_size,
                block_state.negative_crops_coords_top_left,
                block_state.negative_target_size,
                block_state.pooled_prompt_embeds.dtype,
                text_encoder_projection_dim=block_state.text_encoder_projection_dim,
                dtype=dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            block_state.negative_add_time_ids = block_state.add_time_ids

        block_state.add_time_ids = block_state.add_time_ids.repeat(
            block_state.batch_size * block_state.num_images_per_prompt, 1
        ).to(device=block_state.device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
            block_state.batch_size * block_state.num_images_per_prompt, 1
        ).to(device=block_state.device)
        block_state.add_time_ids = block_state.add_time_ids.repeat(final_batch_size, 1).to(device=device)
        block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(final_batch_size, 1).to(device=device)

        # Optionally get Guidance Scale Embedding for LCM
        block_state.timestep_cond = None
@@ -1488,9 +1296,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
            and components.unet.config.time_cond_proj_dim is not None
        ):
            # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
            block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
                block_state.batch_size * block_state.num_images_per_prompt
            )
            block_state.guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
            block_state.timestep_cond = self.get_guidance_scale_embedding(
                block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
            ).to(device=block_state.device, dtype=block_state.latents.dtype)
@@ -1613,14 +1419,18 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # (1) prepare controlnet inputs
        block_state.device = components._execution_device
        block_state.height, block_state.width = block_state.latents.shape[-2:]
        block_state.height = block_state.height * components.vae_scale_factor
        block_state.width = block_state.width * components.vae_scale_factor

        controlnet = unwrap_module(components.controlnet)

        device = components._execution_device
        dtype = components.controlnet.dtype

        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        # (1) prepare controlnet inputs

        # (1.1)
        # control_guidance_start/control_guidance_end (align format)
        if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1670,12 +1480,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
            block_state.control_image = self.prepare_control_image(
                components,
                image=block_state.control_image,
                width=block_state.width,
                height=block_state.height,
                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
                width=width,
                height=height,
                batch_size=final_batch_size,
                num_images_per_prompt=block_state.num_images_per_prompt,
                device=block_state.device,
                dtype=controlnet.dtype,
                device=device,
                dtype=dtype,
                crops_coords=block_state.crops_coords,
            )
        elif isinstance(controlnet, MultiControlNetModel):
@@ -1685,12 +1495,12 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
                control_image = self.prepare_control_image(
                    components,
                    image=control_image_,
                    width=block_state.width,
                    height=block_state.height,
                    batch_size=block_state.batch_size * block_state.num_images_per_prompt,
                    width=width,
                    height=height,
                    batch_size=final_batch_size,
                    num_images_per_prompt=block_state.num_images_per_prompt,
                    device=block_state.device,
                    dtype=controlnet.dtype,
                    device=device,
                    dtype=dtype,
                    crops_coords=block_state.crops_coords,
                )

@@ -1852,9 +1662,10 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
        device = components._execution_device
        dtype = block_state.dtype or components.controlnet.dtype

        block_state.height, block_state.width = block_state.latents.shape[-2:]
        block_state.height = block_state.height * components.vae_scale_factor
        block_state.width = block_state.width * components.vae_scale_factor
        _, _, height_latents, width_latents = block_state.latents.shape
        height = height_latents * components.vae_scale_factor
        width = width_latents * components.vae_scale_factor
        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        # control_guidance_start/control_guidance_end (align format)
        if not isinstance(block_state.control_guidance_start, list) and isinstance(
@@ -1900,15 +1711,15 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
            block_state.control_image[idx] = self.prepare_control_image(
                components,
                image=block_state.control_image[idx],
                width=block_state.width,
                height=block_state.height,
                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
                width=width,
                height=height,
                batch_size=final_batch_size,
                num_images_per_prompt=block_state.num_images_per_prompt,
                device=device,
                dtype=dtype,
                crops_coords=block_state.crops_coords,
            )
            block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
            _, _, height, width = block_state.control_image[idx].shape

        # controlnet_keep
        block_state.controlnet_keep = []

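For reference, a small sketch (not part of this diff) of how the `controlnet_keep` list started above is conventionally filled: 1.0 while the current step falls inside the [control_guidance_start, control_guidance_end] window, otherwise 0.0; the step count and window are assumptions:

num_steps = 10
control_guidance_start, control_guidance_end = 0.0, 0.8

controlnet_keep = [
    1.0 - float(i / num_steps < control_guidance_start or (i + 1) / num_steps > control_guidance_end)
    for i in range(num_steps)
]
print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
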
@@ -813,29 +813,9 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

        if masked_image is not None and masked_image.shape[1] == 4:
            masked_image_latents = masked_image
        else:
            masked_image_latents = None

        if masked_image is not None:
            if masked_image_latents is None:
                masked_image = masked_image.to(device=device, dtype=dtype)
                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

            if masked_image_latents.shape[0] < batch_size:
                if not batch_size % masked_image_latents.shape[0] == 0:
                    raise ValueError(
                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                        " Make sure the number of images that you pass is divisible by the total requested batch size."
                    )
                masked_image_latents = masked_image_latents.repeat(
                    batch_size // masked_image_latents.shape[0], 1, 1, 1
                )

            # aligning device to prevent device errors when concating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

        return mask, masked_image_latents
