From 1ae9b0595f28df9cc92df87cf49193ec8ca07245 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Fri, 31 Jan 2025 03:45:49 -0800
Subject: [PATCH] Fix enable memory efficient attention on ROCm (#10564)

* fix enable memory efficient attention on ROCm while calling CK implementation

* Update attention_processor.py

refactor of picking a set element
---
 src/diffusers/models/attention_processor.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 26625753e4..5d873baf8f 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -405,11 +405,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
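
Reviewer note: below is a minimal standalone sketch of the probe logic this hunk introduces, for trying it outside of diffusers. The helper name `probe_memory_efficient_attention` is hypothetical; `SUPPORTED_DTYPES` is the dtype set that xformers attention ops expose. The point of the change is that the old probe always allocated fp32 tensors, which the ROCm CK implementation rejects, so the probe now draws a dtype the pinned op actually supports.

```python
# Hypothetical standalone version of the probe above (not part of the patch).
# `attention_op` mirrors diffusers' optional (forward_op, backward_op) tuple.
import torch
import xformers.ops


def probe_memory_efficient_attention(attention_op=None):
    dtype = None  # None lets torch.randn fall back to the default dtype (fp32)
    if attention_op is not None:
        op_fw, _op_bw = attention_op
        # SUPPORTED_DTYPES is a set; unpack an arbitrary element from it
        # so the probe tensor matches a dtype the op can actually run.
        dtype, *_ = op_fw.SUPPORTED_DTYPES
    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
    # Raises if memory-efficient attention cannot run on this input;
    # the caller treats success as "xformers attention is usable".
    _ = xformers.ops.memory_efficient_attention(q, q, q)
```

On a ROCm build of xformers, passing the CK forward/backward ops as `attention_op` makes the probe allocate a half-precision tensor instead of failing on fp32, which is what makes enabling memory-efficient attention work there.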