From a7bf77fc284810483f1e60afe34d1d27ad91ce2e Mon Sep 17 00:00:00 2001 From: Aleksei Zhuravlev Date: Mon, 20 May 2024 18:14:34 +0100 Subject: [PATCH] Passing `cross_attention_kwargs` to `StableDiffusionInstructPix2PixPipeline` (#7961) * Update pipeline_stable_diffusion_instruct_pix2pix.py Add `cross_attention_kwargs` to `__call__` method of `StableDiffusionInstructPix2PixPipeline`, which are passed to UNet. * Update documentation for pipeline_stable_diffusion_instruct_pix2pix.py * Update docstring * Update docstring * Fix typing import --- .../pipeline_stable_diffusion_instruct_pix2pix.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 1443c8b0af..35166313ae 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image @@ -180,6 +180,7 @@ class StableDiffusionInstructPix2PixPipeline( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + cross_attention_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): r""" @@ -239,6 +240,9 @@ class StableDiffusionInstructPix2PixPipeline( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, if specified, is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: @@ -415,6 +419,7 @@ class StableDiffusionInstructPix2PixPipeline( t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0]