From 8b3d2aeaf8ed1489752a9dc4ebf69e72c7af6bf0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 17 Oct 2023 11:17:06 +0530 Subject: [PATCH] [Core] Fix/pipeline without text encoders for SDXL (#5301) * fix: sdxl pipeline when unet is not available. * fix more * account for text * fix more * don't make unet optional. * Apply suggestions from code review Co-authored-by: Patrick von Platen * split conditionals. * add optional components to sdxl pipeline * propagate changes to the rest of the pipelines. * add: test * add to all * fix: rest of the pipelines. * use pipeline_class variable * separate pipeline mixin * use safe_serialization * fix: test * access actual output. * add: optional test to adapter and ip2p sdxl pipeline tests/ * add optional test to controlnet sdxl. * fix tests * fix ip2p tests * fix more * fix more. * use np output type. * fix for StableDiffusionXLMultiControlNetPipelineFastTests. * fix: SDXLOptionalComponentsTesterMixin * Apply suggestions from code review Co-authored-by: Patrick von Platen * fix tests * Empty-Commit * revert previous * quality * fix: test --------- Co-authored-by: Patrick von Platen --- .../pipeline_controlnet_inpaint_sd_xl.py | 62 ++++++-- .../controlnet/pipeline_controlnet_sd_xl.py | 67 +++++--- .../pipeline_controlnet_sd_xl_img2img.py | 55 +++++-- .../pipeline_stable_diffusion_xl.py | 63 ++++++-- .../pipeline_stable_diffusion_xl_img2img.py | 55 +++++-- .../pipeline_stable_diffusion_xl_inpaint.py | 54 +++++-- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 38 ++++- .../pipeline_stable_diffusion_xl_adapter.py | 62 ++++++-- .../versatile_diffusion/modeling_text_unet.py | 2 + .../controlnet/test_controlnet_sdxl.py | 24 ++- .../test_stable_diffusion_xl.py | 11 +- .../test_stable_diffusion_xl_adapter.py | 13 +- .../test_stable_diffusion_xl_img2img.py | 7 +- ...stable_diffusion_xl_instruction_pix2pix.py | 16 +- tests/pipelines/test_pipelines_common.py | 144 ++++++++++++++++++ 15 files changed, 545 insertions(+), 128 deletions(-)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 4418ede74b..cf51fbe571 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -168,7 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline( [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -317,12 +317,17 @@ class StableDiffusionXLControlNetInpaintPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -438,7 +443,11 @@ class StableDiffusionXLControlNetInpaintPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + 
bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -447,7 +456,12 @@ class StableDiffusionXLControlNetInpaintPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -459,10 +473,15 @@ class StableDiffusionXLControlNetInpaintPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -885,7 +904,14 @@ class StableDiffusionXLControlNetInpaintPipeline( return timesteps, num_inference_steps - t_start def _get_add_time_ids( - self, 
original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -895,7 +921,7 @@ class StableDiffusionXLControlNetInpaintPipeline( add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1391,6 +1417,11 @@ class StableDiffusionXLControlNetInpaintPipeline( # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1398,6 +1429,7 @@ class StableDiffusionXLControlNetInpaintPipeline( aesthetic_score, negative_aesthetic_score, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index f634f3f389..5957366586 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline( watermark output images. 
If not defined, it defaults to `True` if the package is installed; otherwise no watermarker is used. """ - model_cpu_offload_seq = ( - "text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet - ) + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -285,12 +285,17 @@ class StableDiffusionXLControlNetPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -406,7 +411,11 @@ class StableDiffusionXLControlNetPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -415,7 +424,12 @@ class 
StableDiffusionXLControlNetPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -427,10 +441,15 @@ class StableDiffusionXLControlNetPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -706,11 +725,13 @@ class StableDiffusionXLControlNetPipeline( return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, 
text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1088,8 +1109,17 @@ class StableDiffusionXLControlNetPipeline( target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: @@ -1098,6 +1128,7 @@ class StableDiffusionXLControlNetPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 3375855ba8..033544e893 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -183,7 +183,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -329,12 +329,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -450,7 +455,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -459,7 +468,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = 
negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -471,10 +485,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -832,6 +851,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -843,7 +863,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + 
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1275,6 +1295,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline( if negative_target_size is None: negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1285,6 +1311,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 55bf929a2e..2658b58de2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -140,6 +140,7 @@ class StableDiffusionXLPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -167,6 +168,7 @@ class StableDiffusionXLPipeline( self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -275,12 +277,17 @@ class StableDiffusionXLPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -396,7 +403,11 @@ class StableDiffusionXLPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly 
method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -405,7 +416,12 @@ class StableDiffusionXLPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -417,10 +433,15 @@ class StableDiffusionXLPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -533,11 +554,13 @@ class StableDiffusionXLPipeline( latents = latents * self.scheduler.init_noise_sigma return latents - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, 
text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -843,8 +866,17 @@ class StableDiffusionXLPipeline( # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -852,6 +884,7 @@ class StableDiffusionXLPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b436f404d5..75eb02a486 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -143,8 +143,7 @@ class StableDiffusionXLImg2ImgPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -282,12 +281,17 @@ class StableDiffusionXLImg2ImgPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -403,7 +407,11 @@ class StableDiffusionXLImg2ImgPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -412,7 +420,12 @@ class StableDiffusionXLImg2ImgPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = 
negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -424,10 +437,15 @@ class StableDiffusionXLImg2ImgPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -618,6 +636,7 @@ class StableDiffusionXLImg2ImgPipeline( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -629,7 +648,7 @@ class StableDiffusionXLImg2ImgPipeline( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * 
len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -983,6 +1002,11 @@ class StableDiffusionXLImg2ImgPipeline( negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -993,6 +1017,7 @@ class StableDiffusionXLImg2ImgPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c04d2c0518..4af25afbeb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -290,7 +290,7 @@ class StableDiffusionXLInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -431,12 +431,17 @@ class StableDiffusionXLInpaintPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + 
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -552,7 +557,11 @@ class StableDiffusionXLInpaintPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -561,7 +570,12 @@ class StableDiffusionXLInpaintPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -573,10 +587,15 @@ class StableDiffusionXLInpaintPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - 
unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -836,6 +855,7 @@ class StableDiffusionXLInpaintPipeline( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -847,7 +867,7 @@ class StableDiffusionXLInpaintPipeline( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1289,6 +1309,11 @@ class StableDiffusionXLInpaintPipeline( negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1299,6 +1324,7 @@ class StableDiffusionXLInpaintPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = 
add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 8cd7f46e63..0427214f83 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -31,11 +31,13 @@ from ...models.attention_processor import ( from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -150,6 +152,7 @@ class StableDiffusionXLInstructPix2PixPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -280,8 +283,17 @@ class StableDiffusionXLInstructPix2PixPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -390,7 +402,8 @@ class StableDiffusionXLInstructPix2PixPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -399,7 +412,7 @@ class StableDiffusionXLInstructPix2PixPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = 
negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -552,11 +565,13 @@ class StableDiffusionXLInstructPix2PixPipeline( return image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -871,8 +886,17 @@ class StableDiffusionXLInstructPix2PixPipeline( # 10. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index d4272696c2..b606b9b50c 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -160,6 +160,7 @@ class StableDiffusionXLAdapterPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, @@ -290,12 +291,17 @@ class StableDiffusionXLAdapterPipeline( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -411,7 +417,11 @@ class StableDiffusionXLAdapterPipeline( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -420,7 +430,12 @@ class StableDiffusionXLAdapterPipeline( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not 
None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -432,10 +447,15 @@ class StableDiffusionXLAdapterPipeline( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -550,11 +570,13 @@ class StableDiffusionXLAdapterPipeline( return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) 
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -928,8 +950,17 @@ class StableDiffusionXLAdapterPipeline( adapter_state[k] = torch.cat([v] * 2, dim=0) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -937,6 +968,7 @@ class StableDiffusionXLAdapterPipeline( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index a70903b4bd..717db3bbdb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from diffusers.utils import deprecate + from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 4fff88434b..be786ebe30 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -42,6 +42,7 @@ from ..test_pipelines_common import ( PipelineKarrasSchedulerTesterMixin, 
PipelineLatentTesterMixin, PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, ) @@ -49,7 +50,11 @@ enable_full_determinism() class StableDiffusionXLControlNetPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -179,6 +184,9 @@ class StableDiffusionXLControlNetPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = [] @@ -324,7 +332,7 @@ class StableDiffusionXLControlNetPipelineFastTests( class StableDiffusionXLMultiControlNetPipelineFastTests( - PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -470,7 +478,7 @@ class StableDiffusionXLMultiControlNetPipelineFastTests( "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", "image": images, } @@ -522,9 +530,12 @@ class StableDiffusionXLMultiControlNetPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( - PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineKarrasSchedulerTesterMixin, 
PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -646,7 +657,7 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": "numpy", + "output_type": "np", "image": images, } @@ -702,6 +713,9 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + def test_negative_conditions(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index cebd860a43..4906670890 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -35,13 +35,15 @@ from diffusers import ( from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() -class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -114,8 +116,6 @@ class 
StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, - # "safety_checker": None, - # "feature_extractor": None, } return components @@ -233,6 +233,9 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = [] diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 92c22ca2c3..0e7a13bc87 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -34,13 +34,19 @@ from diffusers.utils import logging from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() -class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLAdapterPipelineFastTests( + PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLAdapterPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS @@ -215,6 +221,9 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te expected_out_image_size, ) + def 
test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index ba7d3e8be3..97c1910894 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -38,7 +38,7 @@ from ..pipeline_params import ( TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() @@ -341,7 +341,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} @@ -600,3 +600,6 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py index ca4017d11b..e20f8a0b54 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py +++ 
b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py @@ -36,14 +36,23 @@ from ..pipeline_params import ( TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) enable_full_determinism() class StableDiffusionXLInstructPix2PixPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} @@ -175,3 +184,6 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests( def test_cfg(self): pass + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 6f2674a7b8..ae13d0d3e9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -974,6 +974,150 @@ class PipelinePushToHubTester(unittest.TestCase): delete_repo(self.org_repo_id, token=TOKEN) +# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders +# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()` +# test for all such pipelines. This requires us to use a custom `encode_prompt()` function. 
+class SDXLOptionalComponentsTesterMixin: + def encode_prompt( + self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None + ): + device = text_encoders[0].device + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + if negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + else: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # for classifier-free guidance + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # for classifier-free guidance + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def _test_save_load_optional_components(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] + prompt = inputs.pop("prompt") + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt(tokenizers, text_encoders, prompt) + 
inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + _ = inputs.pop("prompt") + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image.