From c8fdfe457229d647d6019e449f3eb6fafb4b6e92 Mon Sep 17 00:00:00 2001
From: Chanchana Sornsoontorn
Date: Wed, 19 Apr 2023 23:51:58 +0700
Subject: [PATCH] Correct `Transformer2DModel.forward` docstring (#3074)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

⚙️chore(transformer_2d) update function signature for encoder_hidden_states
---
 src/diffusers/models/transformer_2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 23364bfa1d..fde1014bd2 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
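
A minimal usage sketch illustrating the corrected type and shape of `encoder_hidden_states` (float conditioning embeddings of shape `(batch size, sequence len, embed dims)`, not a `torch.LongTensor`). The layer sizes and the sequence length of 77 below are illustrative assumptions, not values taken from the patch:

    import torch
    from diffusers.models import Transformer2DModel

    # A small continuous-input Transformer2DModel; sizes are illustrative.
    model = Transformer2DModel(
        num_attention_heads=8,
        attention_head_dim=8,     # inner dim = 8 * 8 = 64
        in_channels=64,
        cross_attention_dim=32,   # embed dims of encoder_hidden_states
    )

    batch, height, width = 2, 16, 16
    hidden_states = torch.randn(batch, 64, height, width)

    # Per the corrected docstring: a FloatTensor of shape
    # (batch size, sequence len, embed dims).
    encoder_hidden_states = torch.randn(batch, 77, 32)

    out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
    print(out.sample.shape)  # torch.Size([2, 64, 16, 16])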