mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] Fix/pipeline without text encoders for SDXL (#5301)
* fix: sdxl pipeline when unet is not available. * fix moe * account for text * ifx more * don't make unet optional. * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * split conditionals. * add optional components to sdxl pipeline * propagate changes to the rest of the pipelines. * add: test * add to all * fix: rest of the pipelines. * use pipeline_class variable * separate pipeline mixin * use safe_serialization * fix: test * access actual output. * add: optional test to adapter and ip2p sdxl pipeline tests/ * add optional test to controlnet sdxl. * fix tests * fix ip2p tests * fix more * fifx more. * use np output type. * fix for StableDiffusionXLMultiControlNetPipelineFastTests. * fix: SDXLOptionalComponentsTesterMixin * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix tests * Empty-Commit * revert previous * quality * fix: test --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -168,7 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -317,12 +317,17 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -438,7 +443,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -447,7 +456,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -459,10 +473,15 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -885,7 +904,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
|
||||
self,
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -895,7 +921,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1391,6 +1417,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1398,6 +1429,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline(
|
||||
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
|
||||
watermarker is used.
|
||||
"""
|
||||
model_cpu_offload_seq = (
|
||||
"text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet
|
||||
)
|
||||
# leave controlnet out on purpose because it iterates with unet
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -285,12 +285,17 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -406,7 +411,11 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -415,7 +424,12 @@ class StableDiffusionXLControlNetPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -427,10 +441,15 @@ class StableDiffusionXLControlNetPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -706,11 +725,13 @@ class StableDiffusionXLControlNetPipeline(
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1088,8 +1109,17 @@ class StableDiffusionXLControlNetPipeline(
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
@@ -1098,6 +1128,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
@@ -183,7 +183,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -329,12 +329,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -450,7 +455,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -459,7 +468,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -471,10 +485,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -832,6 +851,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -843,7 +863,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1275,6 +1295,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
if negative_target_size is None:
|
||||
negative_target_size = target_size
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1285,6 +1311,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ class StableDiffusionXLPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -167,6 +168,7 @@ class StableDiffusionXLPipeline(
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
@@ -275,12 +277,17 @@ class StableDiffusionXLPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -396,7 +403,11 @@ class StableDiffusionXLPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -405,7 +416,12 @@ class StableDiffusionXLPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -417,10 +433,15 @@ class StableDiffusionXLPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -533,11 +554,13 @@ class StableDiffusionXLPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -843,8 +866,17 @@ class StableDiffusionXLPipeline(
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
@@ -852,6 +884,7 @@ class StableDiffusionXLPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
@@ -143,8 +143,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -282,12 +281,17 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -403,7 +407,11 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -412,7 +420,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -424,10 +437,15 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -618,6 +636,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -629,7 +648,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -983,6 +1002,11 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -993,6 +1017,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -431,12 +431,17 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -552,7 +557,11 @@ class StableDiffusionXLInpaintPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -561,7 +570,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -573,10 +587,15 @@ class StableDiffusionXLInpaintPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -836,6 +855,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -847,7 +867,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1289,6 +1309,11 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1299,6 +1324,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -31,11 +31,13 @@ from ...models.attention_processor import (
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -150,6 +152,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -280,8 +283,17 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -390,7 +402,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -399,7 +412,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -552,11 +565,13 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -871,8 +886,17 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
|
||||
@@ -160,6 +160,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -290,12 +291,17 @@ class StableDiffusionXLAdapterPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -411,7 +417,11 @@ class StableDiffusionXLAdapterPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -420,7 +430,12 @@ class StableDiffusionXLAdapterPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -432,10 +447,15 @@ class StableDiffusionXLAdapterPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -550,11 +570,13 @@ class StableDiffusionXLAdapterPipeline(
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -928,8 +950,17 @@ class StableDiffusionXLAdapterPipeline(
|
||||
adapter_state[k] = torch.cat([v] * 2, dim=0)
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
@@ -937,6 +968,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
@@ -5,6 +5,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.utils import deprecate
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...models.activations import get_activation
|
||||
|
||||
@@ -42,6 +42,7 @@ from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +50,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -179,6 +184,9 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
@@ -324,7 +332,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
|
||||
|
||||
class StableDiffusionXLMultiControlNetPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -470,7 +478,7 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
"image": images,
|
||||
}
|
||||
|
||||
@@ -522,9 +530,12 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
return self._test_save_load_optional_components()
|
||||
|
||||
|
||||
class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
|
||||
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -646,7 +657,7 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
"image": images,
|
||||
}
|
||||
|
||||
@@ -702,6 +713,9 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
def test_negative_conditions(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
@@ -35,13 +35,15 @@ from diffusers import (
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class StableDiffusionXLPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -114,8 +116,6 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
# "safety_checker": None,
|
||||
# "feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -233,6 +233,9 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
|
||||
@@ -34,13 +34,19 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
assert_mean_pixel_difference,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class StableDiffusionXLAdapterPipelineFastTests(
|
||||
PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLAdapterPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
@@ -215,6 +221,9 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
expected_out_image_size,
|
||||
)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
return self._test_save_load_optional_components()
|
||||
|
||||
|
||||
class StableDiffusionXLMultiAdapterPipelineFastTests(
|
||||
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
@@ -38,7 +38,7 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -341,7 +341,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
@@ -600,3 +600,6 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
@@ -36,14 +36,23 @@ from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLInstructPix2PixPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
|
||||
@@ -175,3 +184,6 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
|
||||
|
||||
def test_cfg(self):
|
||||
pass
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
@@ -974,6 +974,150 @@ class PipelinePushToHubTester(unittest.TestCase):
|
||||
delete_repo(self.org_repo_id, token=TOKEN)
|
||||
|
||||
|
||||
# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
|
||||
# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
|
||||
# test for all such pipelines. This requires us to use a custom `encode_prompt()` function.
|
||||
class SDXLOptionalComponentsTesterMixin:
|
||||
def encode_prompt(
|
||||
self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None
|
||||
):
|
||||
device = text_encoders[0].device
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
else:
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
uncond_input = tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# for classifier-free guidance
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
# for classifier-free guidance
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
def _test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
for optional_component in pipe._optional_components:
|
||||
setattr(pipe, optional_component, None)
|
||||
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
|
||||
tokenizer = components.pop("tokenizer")
|
||||
tokenizer_2 = components.pop("tokenizer_2")
|
||||
text_encoder = components.pop("text_encoder")
|
||||
text_encoder_2 = components.pop("text_encoder_2")
|
||||
|
||||
tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2]
|
||||
text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2]
|
||||
prompt = inputs.pop("prompt")
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(tokenizers, text_encoders, prompt)
|
||||
inputs["prompt_embeds"] = prompt_embeds
|
||||
inputs["negative_prompt_embeds"] = negative_prompt_embeds
|
||||
inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
|
||||
inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
|
||||
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for optional_component in pipe._optional_components:
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
_ = inputs.pop("prompt")
|
||||
inputs["prompt_embeds"] = prompt_embeds
|
||||
inputs["negative_prompt_embeds"] = negative_prompt_embeds
|
||||
inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
|
||||
inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
|
||||
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
Reference in New Issue
Block a user