Mirror of https://github.com/huggingface/diffusers.git — synced 2026-01-29 07:22:12 +03:00

Add npu support (#7144)

* Add npu support

* fix for code quality check

* fix for code quality check
This commit is contained in:
Mengqing Cao
2024-03-09 11:12:55 +08:00
committed by GitHub
parent 3f9c746fb2
commit e6ff752840
5 changed files with 32 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ from ..utils import (
is_note_seq_available,
is_onnx_available,
is_torch_available,
is_torch_npu_available,
is_transformers_available,
)

View File

@@ -53,12 +53,19 @@ from ..utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module
if is_torch_npu_available():
import torch_npu # noqa: F401
from .pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES,
CONNECTED_PIPES_KEYS,

View File

@@ -12,6 +12,7 @@ from .utils import (
convert_state_dict_to_peft,
deprecate,
is_peft_available,
is_torch_npu_available,
is_torchvision_available,
is_transformers_available,
)
@@ -26,6 +27,9 @@ if is_peft_available():
if is_torchvision_available():
from torchvision import transforms
if is_torch_npu_available():
import torch_npu # noqa: F401
def set_seed(seed: int):
    """
    Seed every random number generator used during training for reproducibility.

    Seeds Python's `random` module, NumPy, and PyTorch. On Ascend hardware the
    NPU generators are seeded; otherwise all visible CUDA devices are seeded.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
        # ^^ safe to call this function even if cuda is not available
def compute_snr(noise_scheduler, timesteps):

View File

@@ -72,6 +72,7 @@ from .import_utils import (
is_scipy_available,
is_tensorboard_available,
is_torch_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
is_torchsde_available,

View File

@@ -14,6 +14,7 @@
"""
Import utilities: Utilities related to imports and our lazy inits.
"""
import importlib.util
import operator as op
import os
@@ -72,6 +73,15 @@ if _torch_xla_available:
except ImportError:
_torch_xla_available = False
# Probe for Huawei Ascend support: the `torch_npu` extension must be both
# importable and expose installation metadata (i.e. a version) to count.
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
    try:
        _torch_npu_version = importlib_metadata.version("torch_npu")
    except ImportError:
        # Spec was found but the distribution metadata is missing — treat as absent.
        _torch_npu_available = False
    else:
        logger.info(f"torch_npu version {_torch_npu_version} available.")
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -294,6 +304,10 @@ def is_torch_xla_available():
return _torch_xla_available
def is_torch_npu_available():
    """Return `True` if the Ascend `torch_npu` extension was detected at import time."""
    return _torch_npu_available
def is_flax_available():
    """Return `True` if Flax was detected at import time."""
    return _flax_available