From 941b7fc0843139e52419a65b7fa850169fde0360 Mon Sep 17 00:00:00 2001 From: chenxiao <154797505+chenxiao111222@users.noreply.github.com> Date: Fri, 11 Jul 2025 05:51:05 +0800 Subject: [PATCH] Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763) * Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) * up --------- Co-authored-by: yiyixuxu --- .../models/transformers/transformer_cosmos.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 3a6cb1ce6e..373b470ae3 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0: key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) # 4. Prepare for GQA - query_idx = torch.tensor(query.size(3), device=query.device) - key_idx = torch.tensor(key.size(3), device=key.device) - value_idx = torch.tensor(value.size(3), device=value.device) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) key = key.repeat_interleave(query_idx // key_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)