1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
yiyixuxu
2025-08-06 04:57:31 +02:00
parent 8946974ccc
commit dc6a4d4cb4
4 changed files with 329 additions and 423 deletions

View File

@@ -91,7 +91,10 @@ class ComponentSpec:
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
# YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future:
# 1. the spec for an optional component should has lower priority when combined in sequential/auto blocks
# 2. should not need to define default_creation_method for optional components
required: bool = True
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})

View File

@@ -418,21 +418,21 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
block_state.device,
block_state.timesteps,
block_state.sigmas,
scheduler=components.scheduler,
num_inference_steps=block_state.num_inference_steps,
device=device,
timesteps=block_state.timesteps,
sigmas=block_state.sigmas,
)
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
components,
block_state.num_inference_steps,
block_state.strength,
device,
components=components,
num_inference_steps=block_state.num_inference_steps,
strength=block_state.strength,
device=device,
denoising_start=block_state.denoising_start
if denoising_value_valid(block_state.denoising_start)
else None,
@@ -498,14 +498,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
block_state.device,
block_state.timesteps,
block_state.sigmas,
scheduler=components.scheduler,
num_inference_steps=block_state.num_inference_steps,
device=device,
timesteps=block_state.timesteps,
sigmas=block_state.sigmas,
)
if (
@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
),
InputParam(
"mask",
"processed_mask_image",
required=True,
type_hint=torch.Tensor,
description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -591,7 +591,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
type_hint=torch.Tensor,
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
]
@property

View File

@@ -57,6 +57,99 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
def get_clip_prompt_embeds(
prompt,
text_encoder,
tokenizer,
device,
clip_skip=None,
max_length=None,
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length if max_length is not None else tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only using the pooled output of the text_encoder_2, which has 2 dimensions
# (pooled output for text_encoder has 3 dimensions)
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
return prompt_embeds, pooled_prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
def encode_vae_image(
image: torch.Tensor,
vae: AutoencoderKL,
generator: torch.Generator,
dtype: torch.dtype,
device: torch.device
):
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
image = image.to(device=device, dtype=dtype)
if vae.config.force_upcast:
image = image.float()
vae.to(dtype=torch.float32)
if isinstance(generator, list) and len(generator) != image.shape[0]:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {image.shape[0]}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
if vae.config.force_upcast:
vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=device, dtype=dtype)
latents_std = latents_std.to(device=device, dtype=dtype)
image_latents = (image_latents - latents_mean) * vae.config.scaling_factor / latents_std
else:
image_latents = vae.config.scaling_factor * image_latents
return image_latents
class StableDiffusionXLIPAdapterStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@@ -86,6 +179,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
required=False,
),
]
@@ -103,12 +197,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=torch.Tensor,
description="Negative IP adapter image embeddings",
),
OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"),
]
@staticmethod
@@ -137,79 +227,36 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
components,
ip_adapter_image,
ip_adapter_image_embeds,
device,
num_images_per_prompt,
prepare_unconditional_embeds,
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.ip_adapter_embeds = []
if components.requires_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
if not isinstance(block_state.ip_adapter_image, list):
block_state.ip_adapter_image = [block_state.ip_adapter_image]
if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
block_state.ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
block_state.ip_adapter_embeds.append(single_image_embeds[None, :])
if components.requires_unconditional_embeds:
block_state.negative_ip_adapter_embeds.append(single_negative_image_embeds[None, :])
self.set_block_state(state, block_state)
return components, state
@@ -225,15 +272,16 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder", CLIPTextModel, required=False),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer", CLIPTokenizer, required=False),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
required=False,
),
]
@@ -244,7 +292,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt", required=True),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
@@ -282,15 +330,25 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_2 is not None and (
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (
not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if negative_prompt_2 is not None and (
not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
):
raise ValueError(f"`negative_prompt_2` has to be of type `str` or `list` but is {type(negative_prompt_2)}")
@staticmethod
def encode_prompt(
@@ -298,14 +356,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
requires_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
@@ -331,20 +384,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
@@ -352,31 +391,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
dtype = components.text_encoder_2.dtype
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = len(prompt)
# Define tokenizers and text encoders
tokenizers = (
@@ -389,58 +409,56 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
if components.text_encoder is not None
else [components.text_encoder_2]
)
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
# dynamically adjust the LoRA scale
for text_encoder in text_encoders:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(text_encoder, lora_scale)
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
scale_lora_layers(text_encoder, lora_scale)
# Define prompts
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompts = [prompt, prompt_2]
prompt_embeds_list.append(prompt_embeds)
# generate prompt_embeds & pooled_prompt_embeds
prompt_embeds_list = []
pooled_prompt_embeds_list = []
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=clip_skip,
max_length=tokenizer.model_max_length
)
prompt_embeds_list.append(prompt_embeds)
if pooled_prompt_embeds.ndim == 2:
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0)
negative_prompt_embeds = None
negative_pooled_prompt_embeds = None
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
# generate negative_prompt_embeds & negative_pooled_prompt_embeds
if requires_unconditional_embeds and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
elif requires_unconditional_embeds:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -451,87 +469,52 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
if batch_size != len(negative_prompt_2):
raise ValueError(
f"`negative_prompt_2`: {negative_prompt_2} has batch size {len(negative_prompt_2)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt_2` matches"
" the batch size of `prompt`."
)
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
negative_pooled_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=None,
max_length=max_length
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
if negative_pooled_prompt_embeds.ndim == 2:
negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype, device=device)
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
if requires_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=components.text_encoder_2.dtype, device=device
)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
for text_encoder in text_encoders:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
unscale_lora_layers(text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -539,13 +522,14 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
device = components._execution_device
dtype = components.text_encoder_2.dtype
# Encode input prompt
block_state.text_encoder_lora_scale = (
lora_scale = (
block_state.cross_attention_kwargs.get("scale", None)
if block_state.cross_attention_kwargs is not None
else None
@@ -557,18 +541,13 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
prompt=block_state.prompt,
prompt2=block_state.prompt_2,
device = device,
requires_unconditional_embeds=components.requires_unconditional_embeds,
negative_prompt=block_state.negative_prompt,
negative_prompt_2=block_state.negative_prompt_2,
lora_scale=lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
@@ -599,8 +578,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
@@ -608,11 +585,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
@@ -622,68 +594,30 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
"image_latents",
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation",
),
OutputParam(
"processed_image",
type_hint=PIL.Image.Image,
description="The preprocessed image",
)
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.batch_size = block_state.image.shape[0]
block_state.processed_image = components.image_processor.preprocess(block_state.image)
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
# Encode image into latents
block_state.image_latents = encode_vae_image(
image=block_state.processed_image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
)
self.set_block_state(state, block_state)
@@ -741,7 +675,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
OutputParam(
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam(
"masked_image_latents",
type_hint=torch.Tensor,
@@ -752,129 +685,89 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
type_hint=Optional[Tuple[int, int]],
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
OutputParam(
"processed_image",
type_hint=PIL.Image.Image,
description="The preprocessed image",
),
OutputParam(
"processed_mask_image",
type_hint=torch.Tensor,
description="The preprocessed mask image",
),
]
def check_inputs(self, image, mask_image, padding_mask_crop):
if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
return mask, masked_image_latents
if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
self.check_inputs(block_state.image, block_state.mask_image, block_state.padding_mask_crop)
dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
device = components._execution_device
if block_state.height is None:
block_state.height = components.default_height
height = components.default_height
if block_state.width is None:
block_state.width = components.default_width
width = components.default_width
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
crops_coords = components.mask_processor.get_crop_region(
mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
)
block_state.resize_mode = "fill"
resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
crops_coords = None
resize_mode = "default"
block_state.image = components.image_processor.preprocess(
block_state.processed_image = components.image_processor.preprocess(
block_state.image,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode,
height=height,
width=width,
crops_coords=crops_coords,
resize_mode=resize_mode,
)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(
block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)
block_state.processed_mask_image = components.mask_processor.preprocess(
block_state.mask_image,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
height=height,
width=width,
resize_mode=resize_mode,
crops_coords=crops_coords,
)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
block_state.image_latents = encode_vae_image(
image=block_state.processed_image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.masked_image_latents = encode_vae_image(
image=masked_image,
vae=components.vae,
generator=block_state.generator,
dtype=dtype,
device=device
)
self.set_block_state(state, block_state)

View File

@@ -89,6 +89,16 @@ class StableDiffusionXLModularPipeline(
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
@property
def requires_unconditional_embeds(self):
# by default, always prepare unconditional embeddings
requires_unconditional_embeds = True
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider.num_conditions > 1
return requires_unconditional_embeds
# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks