From 3768d4d77ced5142430c30108ccb0dcd52f9d151 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 30 Aug 2023 08:57:26 +0530 Subject: [PATCH] [Core] refactor encode_prompt (#4617) * refactoring of encode_prompt() * better handling of device. * fix: device determination * fix: device determination 2 * handle num_images_per_prompt * revert changes in loaders.py and give birth to encode_prompt(). * minor refactoring for encode_prompt()/ * make backward compatible. * Apply suggestions from code review Co-authored-by: Patrick von Platen * fix: concatenation of the neg and pos embeddings. * incorporate encode_prompt() in test_stable_diffusion.py * turn it into big PR. * make it bigger * gligen fixes. * more fixes to fligen * _encode_prompt -> encode_prompt in tests * first batch * second batch * fix blasphemous mistake * fix * fix: hopefully for the final time. --------- Co-authored-by: Patrick von Platen --- .../alt_diffusion/pipeline_alt_diffusion.py | 49 +++++- .../pipeline_alt_diffusion_img2img.py | 49 +++++- .../controlnet/pipeline_controlnet.py | 48 +++++- .../controlnet/pipeline_controlnet_img2img.py | 48 +++++- .../controlnet/pipeline_controlnet_inpaint.py | 47 +++++- .../pipeline_cycle_diffusion.py | 52 +++++- .../pipeline_stable_diffusion.py | 46 ++++- ...line_stable_diffusion_attend_and_excite.py | 49 +++++- .../pipeline_stable_diffusion_depth2img.py | 47 +++++- .../pipeline_stable_diffusion_diffedit.py | 65 +++++-- .../pipeline_stable_diffusion_gligen.py | 48 +++++- .../pipeline_stable_diffusion_img2img.py | 47 +++++- .../pipeline_stable_diffusion_inpaint.py | 47 +++++- ...ipeline_stable_diffusion_inpaint_legacy.py | 47 +++++- .../pipeline_stable_diffusion_k_diffusion.py | 49 +++++- .../pipeline_stable_diffusion_ldm3d.py | 48 +++++- ...pipeline_stable_diffusion_model_editing.py | 49 +++++- .../pipeline_stable_diffusion_panorama.py | 49 +++++- .../pipeline_stable_diffusion_paradigms.py | 48 +++++- .../pipeline_stable_diffusion_pix2pix_zero.py | 54 +++++- .../pipeline_stable_diffusion_sag.py | 49 +++++- .../pipeline_stable_diffusion_upscale.py | 47 +++++- .../pipeline_stable_unclip.py | 56 +++++- .../pipeline_stable_unclip_img2img.py | 159 +++++++++++------- .../pipeline_stable_diffusion_adapter.py | 48 +++++- .../pipeline_text_to_video_synth.py | 48 +++++- .../pipeline_text_to_video_synth_img2img.py | 48 +++++- .../stable_diffusion/test_stable_diffusion.py | 12 +- .../test_stable_diffusion.py | 12 +- .../test_stable_diffusion_upscale.py | 5 +- 30 files changed, 1180 insertions(+), 290 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 6aabf17f26..8cf3085884 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -258,12 +258,45 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = ( + "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`" + " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ ) + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -402,12 +435,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -634,7 +662,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -644,6 +672,11 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 64be253d3b..89c8279fb3 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -259,12 +259,45 @@ class AltDiffusionImg2ImgPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = ( + "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`" + " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ ) + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -403,12 +436,7 @@ class AltDiffusionImg2ImgPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -668,7 +696,7 @@ class AltDiffusionImg2ImgPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -678,6 +706,11 @@ class AltDiffusionImg2ImgPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 734c8d025f..84e89b5049 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -28,6 +28,7 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, is_compiled_module, @@ -250,12 +251,43 @@ class StableDiffusionControlNetPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -394,12 +426,7 @@ class StableDiffusionControlNetPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -842,7 +869,7 @@ class StableDiffusionControlNetPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -852,6 +879,11 @@ class StableDiffusionControlNetPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image if isinstance(controlnet, ControlNetModel): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index ec58304600..7c2db0d8c7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -276,12 +276,43 @@ class StableDiffusionControlNetImg2ImgPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -420,12 +451,7 @@ class StableDiffusionControlNetImg2ImgPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -921,7 +947,7 @@ class StableDiffusionControlNetImg2ImgPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -931,6 +957,12 @@ class StableDiffusionControlNetImg2ImgPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + # 4. Prepare image image = self.image_processor.preprocess(image).to(dtype=torch.float32) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 2719b385f2..271f996e82 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -402,12 +402,43 @@ class StableDiffusionControlNetInpaintPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. 
Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -546,12 +577,7 @@ class StableDiffusionControlNetInpaintPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1126,7 +1152,7 @@ class StableDiffusionControlNetInpaintPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -1136,6 +1162,11 @@ class StableDiffusionControlNetInpaintPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 7ce4734b13..ceeaeb6494 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -270,12 +270,43 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -414,12 +445,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs def check_inputs( @@ -738,7 +764,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds_tuple = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -746,9 +772,17 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor prompt_embeds=prompt_embeds, lora_scale=text_encoder_lora_scale, ) - source_prompt_embeds = self._encode_prompt( + source_prompt_embeds_tuple = self.encode_prompt( source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None ) + if prompt_embeds_tuple[1] is not None: + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + else: + prompt_embeds = prompt_embeds_tuple[0] + if source_prompt_embeds_tuple[1] is not None: + source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]]) + else: + source_prompt_embeds = source_prompt_embeds_tuple[0] # 4. Preprocess image image = self.image_processor.preprocess(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9bc2ad57fd..fd65582237 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -260,12 +260,42 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -404,12 +434,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -634,7 +659,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -644,6 +669,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 9b3d2d9ead..0dd9132725 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, randn_tensor, replace_example_docstring +from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -259,12 +259,43 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
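+        # Shim for backwards compatibility: delegate to `encode_prompt()` and restore
+        # the legacy single-tensor output by concatenating [negative, positive] below.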
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -403,12 +434,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -810,7 +836,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -819,6 +845,11 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 1713050364..f4d352b895 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -144,12 +144,43 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -288,12 +319,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -631,7 +657,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -641,6 +667,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare depth mask depth_mask = self.prepare_depth_map( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 643b0994c0..c26c967100 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -445,12 +445,43 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -589,12 +620,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -970,7 +996,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM # 3. 
Encode input prompts (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) - target_prompt_embeds = self._encode_prompt( + target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt( target_prompt, device, num_maps_per_mask, @@ -979,8 +1005,13 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM prompt_embeds=target_prompt_embeds, negative_prompt_embeds=target_negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds]) - source_prompt_embeds = self._encode_prompt( + source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt( source_prompt, device, num_maps_per_mask, @@ -989,6 +1020,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM prompt_embeds=source_prompt_embeds, negative_prompt_embeds=source_negative_prompt_embeds, ) + if do_classifier_free_guidance: + source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) @@ -1178,7 +1211,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM ) # 5. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -1187,6 +1220,11 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 6. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) @@ -1410,7 +1448,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -1420,6 +1458,11 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Preprocess mask mask_image = preprocess_mask(mask_image, batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py index 46c6d14c3e..8293d4c59d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py @@ -26,6 +26,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -234,12 +235,43 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline): prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -378,12 +410,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -654,7 +681,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -663,6 +690,11 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline): prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e3133cc107..69d040ad59 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -264,12 +264,43 @@ class StableDiffusionImg2ImgPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -408,12 +439,7 @@ class StableDiffusionImg2ImgPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -674,7 +700,7 @@ class StableDiffusionImg2ImgPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -684,6 +710,11 @@ class StableDiffusionImg2ImgPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 5c9d8d8514..cef7114611 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -330,12 +330,43 @@ class StableDiffusionInpaintPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -474,12 +505,7 @@ class StableDiffusionInpaintPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -851,7 +877,7 @@ class StableDiffusionInpaintPipeline( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -861,6 +887,11 @@ class StableDiffusionInpaintPipeline( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index a25a753da5..b0688d50a0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -260,12 +260,43 @@ class StableDiffusionInpaintPipelineLegacy( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -404,12 +435,7 @@ class StableDiffusionInpaintPipelineLegacy( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -643,7 +669,7 @@ class StableDiffusionInpaintPipelineLegacy( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -653,6 +679,11 @@ class StableDiffusionInpaintPipelineLegacy( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess image and mask if not isinstance(image, torch.FloatTensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 6a4718188f..650fb647de 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -24,7 +24,7 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...schedulers import LMSDiscreteScheduler -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor +from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -167,12 +167,43 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
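+        # `encode_prompt()` returns `(prompt_embeds, negative_prompt_embeds)`, so
+        # `prompt_embeds_tuple[1]` (the unconditional embeddings) comes first below.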
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -311,12 +342,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -523,7 +549,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade raise ValueError("has to use guidance_scale") # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -532,6 +558,11 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 202691bfdb..326830f0f9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -27,6 +27,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BaseOutput, + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -229,12 +230,43 @@ class StableDiffusionLDM3DPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -373,12 +405,7 @@ class StableDiffusionLDM3DPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -585,7 +612,7 @@ class StableDiffusionLDM3DPipeline( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -594,6 +621,11 @@ class StableDiffusionLDM3DPipeline( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index ae1d7bad67..1e9a6a9dfa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -24,7 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin -from ...utils import logging, randn_tensor +from ...utils import deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -171,12 +171,43 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -315,12 +346,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -679,7 +705,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -689,6 +715,11 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index cdde5efdba..84e3e209ae 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -23,7 +23,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler -from ...utils import logging, randn_tensor, replace_example_docstring +from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -148,12 +148,43 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -292,12 +323,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -566,7 +592,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -576,6 +602,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py index c3b49839dd..cbbc6c1977 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -23,6 +23,7 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -213,12 +214,43 @@ class StableDiffusionParadigmsPipeline( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -357,12 +389,7 @@ class StableDiffusionParadigmsPipeline( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -590,7 +617,7 @@ class StableDiffusionParadigmsPipeline( do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -599,6 +626,11 @@ class StableDiffusionParadigmsPipeline( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index a67bb3e87e..b7f1eb19a7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -403,12 +403,43 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -547,12 +578,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -907,7 +933,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -916,6 +942,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1168,13 +1199,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 5. Encode input prompt num_images_per_prompt = 1 - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, prompt_embeds=prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 1b0ae8442b..af82899130 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -24,7 +24,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, randn_tensor, replace_example_docstring +from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -171,12 +171,43 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -315,12 +346,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -549,7 +575,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) do_self_attention_guidance = sag_scale > 0.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -558,6 +584,11 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 8f0776361f..a0cadaaa52 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -192,12 +192,43 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -336,12 +367,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -629,7 +655,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -639,6 +665,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 96ac85639d..c7b916c387 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -295,12 +302,43 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -439,12 +477,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): @@ -819,7 +852,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, @@ -829,6 +862,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 9. Prepare image embeddings image_embeds = self.noise_image_embeddings( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index ff79dbc13d..99a279d809 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -196,12 +196,98 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def _encode_image( + self, + image, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level, + generator, + image_embeds, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if isinstance(image, PIL.Image.Image): + # the image embedding should repeated so it matches the total batch size of the prompt + repeat_by = batch_size + else: + # assume the image input is already properly batched and just needs to be repeated so + # it matches the num_images_per_prompt. + # + # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched + # `image_embeds`. If those happen to be common use cases, let's think harder about + # what the expected dimensions of inputs should be and how we handle the encoding. + repeat_by = num_images_per_prompt + + if image_embeds is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeds = image_embeds.unsqueeze(1) + bs_embed, seq_len, _ = image_embeds.shape + image_embeds = image_embeds.repeat(1, repeat_by, 1) + image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) + image_embeds = image_embeds.squeeze(1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -340,67 +426,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def _encode_image( - self, - image, - device, - batch_size, - num_images_per_prompt, - do_classifier_free_guidance, - noise_level, - generator, - image_embeds, - ): - dtype = next(self.image_encoder.parameters()).dtype - - if isinstance(image, PIL.Image.Image): - # the image embedding should repeated so it matches the total batch size of the prompt - repeat_by = batch_size - else: - # assume the image input is already properly batched and just needs to be repeated so - # it matches the num_images_per_prompt. - # - # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched - # `image_embeds`. If those happen to be common use cases, let's think harder about - # what the expected dimensions of inputs should be and how we handle the encoding. - repeat_by = num_images_per_prompt - - if image_embeds is None: - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - - image_embeds = self.noise_image_embeddings( - image_embeds=image_embeds, - noise_level=noise_level, - generator=generator, - ) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - image_embeds = image_embeds.unsqueeze(1) - bs_embed, seq_len, _ = image_embeds.shape - image_embeds = image_embeds.repeat(1, repeat_by, 1) - image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) - image_embeds = image_embeds.squeeze(1) - - if do_classifier_free_guidance: - negative_prompt_embeds = torch.zeros_like(image_embeds) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) - - return image_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): @@ -718,7 +744,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, @@ -728,6 +754,11 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Encoder input image noise_level = torch.tensor([noise_level], device=device) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 1ee6f9296d..ccb9215b8e 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, BaseOutput, + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -255,12 +256,43 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -399,12 +431,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -692,7 +719,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -701,6 +728,11 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index aefd858e77..5315329e69 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -23,6 +23,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -182,12 +183,43 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -326,12 +358,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -561,7 +588,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -571,6 +598,11 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 7da80b2342..5122f114f3 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -24,6 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + deprecate, is_accelerate_available, is_accelerate_version, logging, @@ -244,12 +245,43 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -388,12 +420,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): @@ -640,7 +667,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -650,6 +677,11 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video video = preprocess_video(video) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index cf7ef2208f..5bd8f00591 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -584,21 +584,27 @@ class StableDiffusionPipelineFastTests( prompt = 25 * "@" with CaptureLogger(logger) as cap_logger_3: - text_embeddings_3 = sd_pipe._encode_prompt( + text_embeddings_3, negative_text_embeddings_3 = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings_3 is not None: + text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3]) prompt = 100 * "@" with CaptureLogger(logger) as cap_logger: - text_embeddings = sd_pipe._encode_prompt( + text_embeddings, negative_text_embeddings = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings is not None: + text_embeddings = torch.cat([negative_text_embeddings, text_embeddings]) negative_prompt = "Hello" with CaptureLogger(logger) as cap_logger_2: - text_embeddings_2 = sd_pipe._encode_prompt( + text_embeddings_2, negative_text_embeddings_2 = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings_2 is not None: + text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2]) assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape assert text_embeddings.shape[1] == 77 diff --git 
a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index fd012a23da..2346492a95 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -245,21 +245,27 @@ class StableDiffusion2PipelineFastTests( prompt = 25 * "@" with CaptureLogger(logger) as cap_logger_3: - text_embeddings_3 = sd_pipe._encode_prompt( + text_embeddings_3, negative_text_embeddings_3 = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings_3 is not None: + text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3]) prompt = 100 * "@" with CaptureLogger(logger) as cap_logger: - text_embeddings = sd_pipe._encode_prompt( + text_embeddings, negative_text_embeddings = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings is not None: + text_embeddings = torch.cat([negative_text_embeddings, text_embeddings]) negative_prompt = "Hello" with CaptureLogger(logger) as cap_logger_2: - text_embeddings_2 = sd_pipe._encode_prompt( + text_embeddings_2, negative_text_embeddings_2 = sd_pipe.encode_prompt( prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + if negative_text_embeddings_2 is not None: + text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2]) assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape assert text_embeddings.shape[1] == 77 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index c143bda959..ab7eb2e0fd 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -251,7 +251,10 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase): image = output.images generator = torch.Generator(device=device).manual_seed(0) - prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False) + prompt_embeds, negative_prompt_embeds = sd_pipe.encode_prompt(prompt, device, 1, False) + if negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + image_from_prompt_embeds = sd_pipe( prompt_embeds=prompt_embeds, image=[low_res_image],
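All of the hunks above follow one pattern: `_encode_prompt()` becomes a thin deprecation shim, `encode_prompt()` returns a `(prompt_embeds, negative_prompt_embeds)` tuple, and each caller concatenates the two tensors only when classifier-free guidance is enabled. For downstream code, a minimal migration sketch follows; the checkpoint id, prompt, and device are illustrative assumptions, not part of this patch:

    import torch
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical checkpoint, chosen only for illustration.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

    do_classifier_free_guidance = True  # i.e. guidance_scale > 1.0

    # Deprecated: _encode_prompt() returned a single tensor with the
    # negative and positive embeddings already concatenated.
    # prompt_embeds = pipe._encode_prompt(
    #     "an astronaut riding a horse", device, 1, do_classifier_free_guidance
    # )

    # New API: encode_prompt() returns a tuple; the caller concatenates
    # only under classifier-free guidance, mirroring the __call__ hunks above.
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        "an astronaut riding a horse", device, 1, do_classifier_free_guidance
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])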