diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 26625753e4..5d873baf8f 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -405,11 +405,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
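
For context, a minimal usage sketch of the path this check guards (the pipeline class and model id below are illustrative, and it assumes a CUDA GPU with xformers installed): `xformers.ops.MemoryEfficientAttentionFlashAttentionOp` is a `(forward_op, backward_op)` tuple whose forward op's `SUPPORTED_DTYPES` contains only half-precision dtypes, so the previous float32 `randn` smoke test raised even though the op works fine at fp16. The patched check builds the test tensors with one of the op's supported dtypes instead.

```python
import torch
from diffusers import StableDiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# Illustrative model id; any diffusers pipeline using Attention modules behaves the same.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# MemoryEfficientAttentionFlashAttentionOp is a (forward_op, backward_op) tuple.
# With the old float32 smoke test this call failed during the capability check;
# with the dtype-aware check above it succeeds, since the test tensors are
# created in a dtype the flash-attention op actually supports.
pipe.enable_xformers_memory_efficient_attention(
    attention_op=MemoryEfficientAttentionFlashAttentionOp
)
```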