@@ -91,7 +91,10 @@ class ComponentSpec:
    type_hint: Optional[Type] = None
    description: Optional[str] = None
    config: Optional[FrozenDict] = None
    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
    # YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future:
    # 1. the spec for an optional component should have lower priority when combined in sequential/auto blocks
    # 2. should not need to define default_creation_method for optional components
    required: bool = True
    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
    subfolder: Optional[str] = field(default="", metadata={"loading": True})
    variant: Optional[str] = field(default=None, metadata={"loading": True})

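For reference, a minimal sketch (not part of the commit) of how a block could declare an optional component with the `required` field added above; the import path and model class are assumptions, mirroring the ComponentSpec usage that appears later in this diff.

# Hypothetical illustration only; import path assumed.
from diffusers.modular_pipelines import ComponentSpec
from transformers import CLIPTextModel

# required=False marks a component the block can still run without
optional_text_encoder = ComponentSpec("text_encoder", CLIPTextModel, required=False)
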
@@ -418,21 +418,21 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
        device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
            scheduler=components.scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=device,
            timesteps=block_state.timesteps,
            sigmas=block_state.sigmas,
        )

        def denoising_value_valid(dnv):
            return isinstance(dnv, float) and 0 < dnv < 1

        block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
            components,
            block_state.num_inference_steps,
            block_state.strength,
            device,
            components=components,
            num_inference_steps=block_state.num_inference_steps,
            strength=block_state.strength,
            device=device,
            denoising_start=block_state.denoising_start
            if denoising_value_valid(block_state.denoising_start)
            else None,

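The keyword-argument call style adopted above can be exercised on its own. A minimal sketch (not part of the commit), assuming a plain scheduler loaded outside the modular pipeline and the existing SDXL pipeline module as the import path for retrieve_timesteps:

# Standalone sketch of the keyword-style retrieve_timesteps call used above.
import torch
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps

scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler=scheduler,
    num_inference_steps=30,
    device=torch.device("cpu"),
    timesteps=None,  # or pass an explicit list of timesteps
    sigmas=None,     # or pass an explicit list of sigmas
)
print(num_inference_steps, timesteps[:3])
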
@@ -498,14 +498,14 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.device = components._execution_device
        device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
            scheduler=components.scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=device,
            timesteps=block_state.timesteps,
            sigmas=block_state.sigmas,
        )

        if (
@@ -581,7 +581,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
            ),
            InputParam(
                "mask",
                "processed_mask_image",
                required=True,
                type_hint=torch.Tensor,
                description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -591,7 +591,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
                type_hint=torch.Tensor,
                description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
            ),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
        ]

    @property

@@ -57,6 +57,99 @@ def retrieve_latents(
    raise AttributeError("Could not access latents of provided encoder_output")


def get_clip_prompt_embeds(
    prompt,
    text_encoder,
    tokenizer,
    device,
    clip_skip=None,
    max_length=None,
):

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length if max_length is not None else tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
        text_input_ids, untruncated_ids
    ):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {tokenizer.model_max_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    # We are only using the pooled output of the text_encoder_2, which has 2 dimensions
    # (pooled output for text_encoder has 3 dimensions)
    pooled_prompt_embeds = prompt_embeds[0]

    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        # "2" because SDXL always indexes from the penultimate layer.
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

    return prompt_embeds, pooled_prompt_embeds

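A minimal sketch (not part of the commit) of calling the new get_clip_prompt_embeds helper on its own; the checkpoint name, subfolders, and CPU device are placeholders, not something the commit prescribes.

# Hypothetical usage of the helper added above.
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2"
)
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
)

prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
    prompt=["a photo of an astronaut riding a horse"],
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    device=torch.device("cpu"),
    clip_skip=None,
    max_length=tokenizer.model_max_length,
)
print(prompt_embeds.shape, pooled_prompt_embeds.shape)
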
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
def encode_vae_image(
    image: torch.Tensor,
    vae: AutoencoderKL,
    generator: torch.Generator,
    dtype: torch.dtype,
    device: torch.device
):
    latents_mean = latents_std = None
    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)

    image = image.to(device=device, dtype=dtype)

    if vae.config.force_upcast:
        image = image.float()
        vae.to(dtype=torch.float32)

    if isinstance(generator, list) and len(generator) != image.shape[0]:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {image.shape[0]}. Make sure the batch size matches the length of the generators."
        )

    if isinstance(generator, list):
        image_latents = [
            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
            for i in range(image.shape[0])
        ]
        image_latents = torch.cat(image_latents, dim=0)
    else:
        image_latents = retrieve_latents(vae.encode(image), generator=generator)

    if vae.config.force_upcast:
        vae.to(dtype)

    image_latents = image_latents.to(dtype)
    if latents_mean is not None and latents_std is not None:
        latents_mean = latents_mean.to(device=device, dtype=dtype)
        latents_std = latents_std.to(device=device, dtype=dtype)
        image_latents = (image_latents - latents_mean) * vae.config.scaling_factor / latents_std
    else:
        image_latents = vae.config.scaling_factor * image_latents

    return image_latents

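A minimal sketch (not part of the commit) of calling the standalone encode_vae_image helper above; the checkpoint name is a placeholder and a tiny random tensor stands in for a preprocessed image.

# Hypothetical usage of the encode_vae_image helper added above.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
image = torch.rand(1, 3, 256, 256) * 2 - 1   # preprocessed image in [-1, 1]
generator = torch.Generator().manual_seed(0)

image_latents = encode_vae_image(
    image=image,
    vae=vae,
    generator=generator,
    dtype=torch.float32,
    device=torch.device("cpu"),
)
print(image_latents.shape)  # (1, 4, 32, 32) for the SDXL VAE (spatial downscale of 8)
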

class StableDiffusionXLIPAdapterStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

@@ -86,6 +179,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
                required=False,
            ),
        ]

@@ -103,12 +197,8 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
            OutputParam(
                "negative_ip_adapter_embeds",
                type_hint=torch.Tensor,
                description="Negative IP adapter image embeddings",
            ),
            OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="IP adapter image embeddings"),
            OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative IP adapter image embeddings"),
        ]

    @staticmethod
@@ -137,79 +227,36 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):

        return image_embeds, uncond_image_embeds

    # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self,
        components,
        ip_adapter_image,
        ip_adapter_image_embeds,
        device,
        num_images_per_prompt,
        prepare_unconditional_embeds,
    ):
        image_embeds = []
        if prepare_unconditional_embeds:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    components, single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if prepare_unconditional_embeds:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if prepare_unconditional_embeds:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if prepare_unconditional_embeds:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device
        device = components._execution_device

        block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
            components,
            ip_adapter_image=block_state.ip_adapter_image,
            ip_adapter_image_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,
            prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
        )
        if block_state.prepare_unconditional_embeds:
        block_state.ip_adapter_embeds = []
        if components.requires_unconditional_embeds:
            block_state.negative_ip_adapter_embeds = []
            for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
                negative_image_embeds, image_embeds = image_embeds.chunk(2)
                block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
                block_state.ip_adapter_embeds[i] = image_embeds

        if not isinstance(block_state.ip_adapter_image, list):
            block_state.ip_adapter_image = [block_state.ip_adapter_image]

        if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            block_state.ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                components, single_ip_adapter_image, device, 1, output_hidden_state
            )

            block_state.ip_adapter_embeds.append(single_image_embeds[None, :])
            if components.requires_unconditional_embeds:
                block_state.negative_ip_adapter_embeds.append(single_negative_image_embeds[None, :])

        self.set_block_state(state, block_state)
        return components, state
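A toy illustration (not part of the commit) of the two embedding layouts seen above: the old code packed [negative, positive] into one tensor and later split it with .chunk(2), while the new block keeps two parallel lists, one entry per IP-Adapter. Shapes are made up.

import torch

packed = torch.randn(2, 4, 16)                        # [negative, positive] stacked on dim 0
negative_embeds, positive_embeds = packed.chunk(2)    # old-style split
print(negative_embeds.shape, positive_embeds.shape)   # torch.Size([1, 4, 16]) twice

# new-style: keep them as separate lists
ip_adapter_embeds = [positive_embeds]
negative_ip_adapter_embeds = [negative_embeds]
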
@@ -225,15 +272,16 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("text_encoder", CLIPTextModel, required=False),
            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("tokenizer", CLIPTokenizer, required=False),
            ComponentSpec("tokenizer_2", CLIPTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
                required=False,
            ),
        ]

@@ -244,7 +292,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt", required=True),
            InputParam("prompt_2"),
            InputParam("negative_prompt"),
            InputParam("negative_prompt_2"),
@@ -282,15 +330,25 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and (
            not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
    def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):

        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_2 is not None and (
            not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
        ):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
        elif block_state.prompt_2 is not None and (
            not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if negative_prompt_2 is not None and (
            not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
        ):
            raise ValueError(f"`negative_prompt_2` has to be of type `str` or `list` but is {type(negative_prompt_2)}")

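A small sketch (not part of the commit) of calling the new argument-based check_inputs signature shown above; since it is a staticmethod it can be invoked directly on the block class, whose importability is assumed.

# Hypothetical call; a non-string, non-list value would trigger the ValueError.
StableDiffusionXLTextEncoderStep.check_inputs(
    prompt="a photo of a cat",
    prompt_2=None,
    negative_prompt=None,
    negative_prompt_2=None,
)
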
    @staticmethod
    def encode_prompt(
@@ -298,14 +356,9 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        requires_unconditional_embeds: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
@@ -331,20 +384,6 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
@@ -352,31 +391,12 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        device = device or components._execution_device
        dtype = components.text_encoder_2.dtype

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
                else:
                    scale_lora_layers(components.text_encoder, lora_scale)

            if components.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        batch_size = len(prompt)

        # Define tokenizers and text encoders
        tokenizers = (
@@ -389,58 +409,56 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            if components.text_encoder is not None
            else [components.text_encoder_2]
        )
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
            components._lora_scale = lora_scale

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(components, TextualInversionLoaderMixin):
                    prompt = components.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
        # dynamically adjust the LoRA scale
        for text_encoder in text_encoders:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder, lora_scale)
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
                scale_lora_layers(text_encoder, lora_scale)

        # Define prompts
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
        prompts = [prompt, prompt_2]

                prompt_embeds_list.append(prompt_embeds)
        # generate prompt_embeds & pooled_prompt_embeds
        prompt_embeds_list = []
        pooled_prompt_embeds_list = []

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(components, TextualInversionLoaderMixin):
                prompt = components.maybe_convert_prompt(prompt, tokenizer)

            prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
                prompt=prompt,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                device=device,
                clip_skip=clip_skip,
                max_length=tokenizer.model_max_length
            )

            prompt_embeds_list.append(prompt_embeds)
            if pooled_prompt_embeds.ndim == 2:
                pooled_prompt_embeds_list.append(pooled_prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0)

        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
        if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
        # generate negative_prompt_embeds & negative_pooled_prompt_embeds
        if requires_unconditional_embeds and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif prepare_unconditional_embeds and negative_prompt_embeds is None:
        elif requires_unconditional_embeds:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

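A toy illustration (not part of the commit) of how the per-encoder embeddings are combined above: hidden states are concatenated on the feature dimension, pooled outputs on the batch dimension. The 768/1280 widths correspond to the two SDXL text encoders; batch and sequence sizes here are made up.

import torch

prompt_embeds_list = [torch.randn(1, 77, 768), torch.randn(1, 77, 1280)]   # CLIP ViT-L, OpenCLIP bigG
pooled_prompt_embeds_list = [torch.randn(1, 1280)]   # only text_encoder_2 yields a 2-D pooled output

prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)               # (1, 77, 2048)
pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0)  # (1, 1280)
print(prompt_embeds.shape, pooled_prompt_embeds.shape)
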
@@ -451,87 +469,52 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
            if batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]
            if batch_size != len(negative_prompt_2):
                raise ValueError(
                    f"`negative_prompt_2`: {negative_prompt_2} has batch size {len(negative_prompt_2)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt_2` matches"
                    " the batch size of `prompt`."
                )
            uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            negative_pooled_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(components, TextualInversionLoaderMixin):
                    negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
                    prompt=negative_prompt,
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    device=device,
                    clip_skip=None,
                    max_length=max_length
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)
                if negative_pooled_prompt_embeds.ndim == 2:
                    negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
            negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)

        if components.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
        prompt_embeds = prompt_embeds.to(dtype, device=device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
        if requires_unconditional_embeds:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if components.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=components.text_encoder_2.dtype, device=device
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if prepare_unconditional_embeds:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if components.text_encoder is not None:
        for text_encoder in text_encoders:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)
                unscale_lora_layers(text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

@@ -539,13 +522,14 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        self.check_inputs(block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device
        device = components._execution_device
        dtype = components.text_encoder_2.dtype

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
        lora_scale = (
            block_state.cross_attention_kwargs.get("scale", None)
            if block_state.cross_attention_kwargs is not None
            else None
@@ -557,18 +541,13 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
            block_state.negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.prompt_2,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            block_state.negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=block_state.text_encoder_lora_scale,
            prompt=block_state.prompt,
            prompt_2=block_state.prompt_2,
            device=device,
            requires_unconditional_embeds=components.requires_unconditional_embeds,
            negative_prompt=block_state.negative_prompt,
            negative_prompt_2=block_state.negative_prompt_2,
            lora_scale=lora_scale,
            clip_skip=block_state.clip_skip,
        )
        # Add outputs
@@ -599,8 +578,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", required=True),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
@@ -608,11 +585,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
        return [
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam(
                "preprocess_kwargs",
                type_hint=Optional[dict],
                description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
            ),
        ]

    @property
@@ -622,68 +594,30 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
                "image_latents",
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation",
            ),
            OutputParam(
                "processed_image",
                type_hint=PIL.Image.Image,
                description="The preprocessed image",
            )
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use #Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.image = components.image_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
        )
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
        device = components._execution_device
        dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.batch_size = block_state.image.shape[0]
        block_state.processed_image = components.image_processor.preprocess(block_state.image)

        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
            )

        block_state.image_latents = self._encode_vae_image(
            components, image=block_state.image, generator=block_state.generator
        # Encode image into latents
        block_state.image_latents = encode_vae_image(
            image=block_state.processed_image,
            vae=components.vae,
            generator=block_state.generator,
            dtype=dtype,
            device=device
        )

        self.set_block_state(state, block_state)
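A toy illustration (not part of the commit) of the latent normalization performed inside encode_vae_image above; real values come from vae.config, the numbers here are stand-ins except the default SDXL scaling factor.

import torch

scaling_factor = 0.13025              # SDXL VAE default
latents = torch.randn(1, 4, 8, 8)

# without latents_mean / latents_std (the common SDXL case):
scaled = scaling_factor * latents

# with latents_mean / latents_std (e.g. some fine-tuned VAEs); zeros/ones are placeholders:
latents_mean = torch.zeros(1, 4, 1, 1)
latents_std = torch.ones(1, 4, 1, 1)
normalized = (latents - latents_mean) * scaling_factor / latents_std
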
@@ -741,7 +675,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
            OutputParam(
                "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
            ),
            OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
            OutputParam(
                "masked_image_latents",
                type_hint=torch.Tensor,
@@ -752,129 +685,89 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
                type_hint=Optional[Tuple[int, int]],
                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
            ),
            OutputParam(
                "processed_image",
                type_hint=PIL.Image.Image,
                description="The preprocessed image",
            ),
            OutputParam(
                "processed_mask_image",
                type_hint=torch.Tensor,
                description="The preprocessed mask image",
            ),
        ]

    def check_inputs(self, image, mask_image, padding_mask_crop):

        if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
            )

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use #Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
    # do not accept do_classifier_free_guidance
    def prepare_mask_latents(
        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)


        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

        return mask, masked_image_latents

        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )

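A toy illustration (not part of the commit) of the mask resizing done in prepare_mask_latents above: a full-resolution binary mask is downsampled to the latent grid. The factor of 8 is the SDXL VAE spatial downscale; the mask geometry is made up.

import torch

vae_scale_factor = 8
mask = torch.zeros(1, 1, 1024, 1024)
mask[:, :, 256:768, 256:768] = 1.0

latent_mask = torch.nn.functional.interpolate(
    mask, size=(1024 // vae_scale_factor, 1024 // vae_scale_factor)
)
print(latent_mask.shape)  # torch.Size([1, 1, 128, 128])
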
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device
        self.check_inputs(block_state.image, block_state.mask_image, block_state.padding_mask_crop)

        dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        device = components._execution_device

        if block_state.height is None:
            block_state.height = components.default_height
            height = components.default_height
        if block_state.width is None:
            block_state.width = components.default_width
            width = components.default_width

        if block_state.padding_mask_crop is not None:
            block_state.crops_coords = components.mask_processor.get_crop_region(
                block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
            crops_coords = components.mask_processor.get_crop_region(
                mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
            )
            block_state.resize_mode = "fill"
            resize_mode = "fill"
        else:
            block_state.crops_coords = None
            block_state.resize_mode = "default"
            crops_coords = None
            resize_mode = "default"

        block_state.image = components.image_processor.preprocess(
        block_state.processed_image = components.image_processor.preprocess(
            block_state.image,
            height=block_state.height,
            width=block_state.width,
            crops_coords=block_state.crops_coords,
            resize_mode=block_state.resize_mode,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )
        block_state.image = block_state.image.to(dtype=torch.float32)

        block_state.mask = components.mask_processor.preprocess(
        block_state.processed_image = block_state.processed_image.to(dtype=torch.float32)

        block_state.processed_mask_image = components.mask_processor.preprocess(
            block_state.mask_image,
            height=block_state.height,
            width=block_state.width,
            resize_mode=block_state.resize_mode,
            crops_coords=block_state.crops_coords,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )
        block_state.masked_image = block_state.image * (block_state.mask < 0.5)

        masked_image = block_state.processed_image * (block_state.processed_mask_image < 0.5)

        block_state.batch_size = block_state.image.shape[0]
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
        block_state.image_latents = self._encode_vae_image(
            components, image=block_state.image, generator=block_state.generator
        block_state.image_latents = encode_vae_image(
            image=block_state.processed_image,
            vae=components.vae,
            generator=block_state.generator,
            dtype=dtype,
            device=device
        )

        # 7. Prepare mask latent variables
        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
            components,
            block_state.mask,
            block_state.masked_image,
            block_state.batch_size,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
        block_state.masked_image_latents = encode_vae_image(
            image=masked_image,
            vae=components.vae,
            generator=block_state.generator,
            dtype=dtype,
            device=device
        )

        self.set_block_state(state, block_state)

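A toy illustration (not part of the commit) of how the masked image is formed above: pixels where the mask is >= 0.5 are zeroed out before VAE encoding. Shapes and the mask region are made up.

import torch

processed_image = torch.rand(1, 3, 64, 64) * 2 - 1    # image in [-1, 1]
processed_mask = torch.zeros(1, 1, 64, 64)
processed_mask[:, :, 16:48, 16:48] = 1.0              # region to inpaint

masked_image = processed_image * (processed_mask < 0.5)
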
@@ -89,6 +89,16 @@ class StableDiffusionXLModularPipeline(
        if hasattr(self, "vae") and self.vae is not None:
            num_channels_latents = self.vae.config.latent_channels
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        # by default, always prepare unconditional embeddings
        requires_unconditional_embeds = True

        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider.num_conditions > 1

        return requires_unconditional_embeds


# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks

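A minimal sketch (not part of the commit) of the requires_unconditional_embeds logic above, using a stand-in guider object instead of the real ClassifierFreeGuidance component.

class _FakeGuider:
    num_conditions = 2  # CFG runs a conditional and an unconditional branch

guider = _FakeGuider()
requires_unconditional_embeds = guider.num_conditions > 1
print(requires_unconditional_embeds)  # True
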