From debc74f442dc74210528eb6d8a4d1f7f27fa18c3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 28 Dec 2022 18:49:04 +0100
Subject: [PATCH] [Versatile Diffusion] Fix cross_attention_kwargs (#1849)

fix versatile
---
 src/diffusers/models/attention.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 91c450d4a5..b6f5158e51 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -703,7 +703,13 @@ class DualTransformer2DModel(nn.Module):
         self.transformer_index_for_condition = [1, 0]
 
     def forward(
-        self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
+        self,
+        hidden_states,
+        encoder_hidden_states,
+        timestep=None,
+        attention_mask=None,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
     ):
         """
         Args:
@@ -738,6 +744,7 @@ class DualTransformer2DModel(nn.Module):
                 input_states,
                 encoder_hidden_states=condition_state,
                 timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
                 return_dict=False,
             )[0]
             encoded_states.append(encoded_state - input_states)
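
The patch makes `DualTransformer2DModel.forward` accept `cross_attention_kwargs` and forward it to each inner transformer; before this change, any caller that supplied the kwarg hit a `TypeError`. Below is a minimal, self-contained sketch of that forwarding pattern, not the actual diffusers classes: `InnerTransformer` and `DualWrapper` are hypothetical stand-ins for `Transformer2DModel` and `DualTransformer2DModel`, and the `"scale"` key is only an illustrative example of what such kwargs might carry.

```python
import torch
import torch.nn as nn


class InnerTransformer(nn.Module):
    """Hypothetical stand-in for Transformer2DModel: it consumes cross_attention_kwargs."""

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
                cross_attention_kwargs=None, return_dict=True):
        # Toy use of the kwargs: scale the hidden states.
        scale = (cross_attention_kwargs or {}).get("scale", 1.0)
        return (hidden_states * scale,)


class DualWrapper(nn.Module):
    """Hypothetical stand-in for DualTransformer2DModel after the fix."""

    def __init__(self):
        super().__init__()
        self.transformers = nn.ModuleList([InnerTransformer(), InnerTransformer()])

    def forward(self, hidden_states, encoder_hidden_states, timestep=None,
                attention_mask=None, cross_attention_kwargs=None, return_dict=True):
        out = hidden_states
        for transformer in self.transformers:
            # The essence of the one-line fix: pass cross_attention_kwargs
            # through to the inner transformer instead of dropping it.
            out = transformer(
                out,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
        return (out,)


x = torch.randn(1, 4, 8, 8)
ctx = torch.randn(1, 77, 32)
model = DualWrapper()
# Before the patch, the equivalent call on DualTransformer2DModel failed with
# "forward() got an unexpected keyword argument 'cross_attention_kwargs'".
y = model(x, ctx, cross_attention_kwargs={"scale": 0.5})[0]
print(y.shape)  # torch.Size([1, 4, 8, 8])
```

The design point is that the wrapper stays transparent: it does not interpret `cross_attention_kwargs` itself, it only relays the dict so the inner attention layers (e.g. custom attention processors) can read it.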