1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Passing cross_attention_kwargs to StableDiffusionInstructPix2PixPipeline (#7961)

* Update pipeline_stable_diffusion_instruct_pix2pix.py

Add `cross_attention_kwargs` to `__call__` method of `StableDiffusionInstructPix2PixPipeline`, which are passed to UNet.

* Update documentation for pipeline_stable_diffusion_instruct_pix2pix.py

* Update docstring

* Update docstring

* Fix typing import
This commit is contained in:
Aleksei Zhuravlev
2024-05-20 18:14:34 +01:00
committed by GitHub
parent 0f0defdb65
commit a7bf77fc28

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
@@ -180,6 +180,7 @@ class StableDiffusionInstructPix2PixPipeline(
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
r"""
@@ -239,6 +240,9 @@ class StableDiffusionInstructPix2PixPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
@@ -415,6 +419,7 @@ class StableDiffusionInstructPix2PixPipeline(
t,
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]