From e6ff75284025323f199313d9e076796c6ff8dd8a Mon Sep 17 00:00:00 2001 From: Mengqing Cao <52243582+MengqingCao@users.noreply.github.com> Date: Sat, 9 Mar 2024 11:12:55 +0800 Subject: [PATCH] Add npu support (#7144) * Add npu support * fix for code quality check * fix for code quality check --- src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/pipeline_utils.py | 7 +++++++ src/diffusers/training_utils.py | 11 +++++++++-- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 14 ++++++++++++++ 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d7aacc4d67..94e8d227f7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -11,6 +11,7 @@ from ..utils import ( is_note_seq_available, is_onnx_available, is_torch_available, + is_torch_npu_available, is_transformers_available, ) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e4ac56f9d0..fa706ea57d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -53,12 +53,19 @@ from ..utils import ( deprecate, is_accelerate_available, is_accelerate_version, + is_torch_npu_available, is_torch_version, logging, numpy_to_pil, ) from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module + + +if is_torch_npu_available(): + import torch_npu # noqa: F401 + + from .pipeline_loading_utils import ( ALL_IMPORTABLE_CLASSES, CONNECTED_PIPES_KEYS, diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 32a1738e9d..25e02a3d14 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -12,6 +12,7 @@ from .utils import ( convert_state_dict_to_peft, deprecate, is_peft_available, + is_torch_npu_available, is_torchvision_available, is_transformers_available, ) @@ -26,6 +27,9 @@ if is_peft_available(): if is_torchvision_available(): from torchvision import transforms +if is_torch_npu_available(): + import torch_npu # noqa: F401 + def set_seed(seed: int): """ @@ -36,8 +40,11 @@ def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # ^^ safe to call this function even if cuda is not available + if is_torch_npu_available(): + torch.npu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available def compute_snr(noise_scheduler, timesteps): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 35aba10d7e..4e2f07f2ba 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -72,6 +72,7 @@ from .import_utils import ( is_scipy_available, is_tensorboard_available, is_torch_available, + is_torch_npu_available, is_torch_version, is_torch_xla_available, is_torchsde_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 9c916737d1..a3ee31c91c 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -14,6 +14,7 @@ """ Import utilities: Utilities related to imports and our lazy inits. """ + import importlib.util import operator as op import os @@ -72,6 +73,15 @@ if _torch_xla_available: except ImportError: _torch_xla_available = False +# check whether torch_npu is available +_torch_npu_available = importlib.util.find_spec("torch_npu") is not None +if _torch_npu_available: + try: + _torch_npu_version = importlib_metadata.version("torch_npu") + logger.info(f"torch_npu version {_torch_npu_version} available.") + except ImportError: + _torch_npu_available = False + _jax_version = "N/A" _flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: @@ -294,6 +304,10 @@ def is_torch_xla_available(): return _torch_xla_available +def is_torch_npu_available(): + return _torch_npu_available + + def is_flax_available(): return _flax_available