From 5939ace91bd090a2593b76f57707c104da427c51 Mon Sep 17 00:00:00 2001 From: Leo Jiang Date: Tue, 27 May 2025 07:59:15 -0600 Subject: [PATCH] Adding NPU for get device function (#11617) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding device choice for npu * Adding device choice for npu --------- Co-authored-by: J石页 Co-authored-by: Sayak Paul --- src/diffusers/utils/torch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 053a3d99b9..bb5674092d 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -18,7 +18,7 @@ PyTorch utilities: Utilities related to PyTorch from typing import List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available, is_torch_version +from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version if is_torch_available(): @@ -166,6 +166,8 @@ def get_torch_cuda_device_capability(): def get_device(): if torch.cuda.is_available(): return "cuda" + elif is_torch_npu_available(): + return "npu" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" else: