1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)

* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
chenxiao
2025-07-11 05:51:05 +08:00
committed by GitHub
parent 76a62ac9cc
commit 941b7fc084

View File

@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)