diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 29d153ba04..2fa4d3b047 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -431,6 +431,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
@@ -481,6 +483,16 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                     f" {negative_prompt_embeds.shape}."
                 )
 
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -668,6 +680,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -738,6 +752,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -809,6 +830,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
@@ -854,6 +877,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             negative_prompt_2,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
         )
 
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 2f76be9269..e9d64c89c6 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -258,3 +258,42 @@ class ControlNetPipelineSDXLFastTests(
 
         # ensure the results are not equal
         assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+    # copied from test_stable_diffusion_xl.py
+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 2 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
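A minimal usage sketch of the new arguments, assuming an already-loaded StableDiffusionXLControlNetPipeline named `pipe` and a prepared ControlNet conditioning image `conditioning_image` (both hypothetical names, not part of the patch). The pooled embeddings returned by `encode_prompt` are forwarded together with the regular prompt embeddings, mirroring the new test above:

# Sketch only: `pipe` is assumed to be a loaded StableDiffusionXLControlNetPipeline
# and `conditioning_image` a prepared ControlNet conditioning image.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("an astronaut riding a green horse")

# Pooled embeddings must accompany the corresponding prompt embeddings,
# otherwise the new check in check_inputs raises a ValueError.
image = pipe(
    image=conditioning_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]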