mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Add npu support (#7144)
* Add npu support * fix for code quality check * fix for code quality check
This commit is contained in:
@@ -11,6 +11,7 @@ from ..utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,12 +53,19 @@ from ..utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
|
||||
from .pipeline_loading_utils import (
|
||||
ALL_IMPORTABLE_CLASSES,
|
||||
CONNECTED_PIPES_KEYS,
|
||||
|
||||
@@ -12,6 +12,7 @@ from .utils import (
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
is_peft_available,
|
||||
is_torch_npu_available,
|
||||
is_torchvision_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -26,6 +27,9 @@ if is_peft_available():
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
if is_torch_npu_available():
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
@@ -36,8 +40,11 @@ def set_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
if is_torch_npu_available():
|
||||
torch.npu.manual_seed_all(seed)
|
||||
else:
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
|
||||
|
||||
def compute_snr(noise_scheduler, timesteps):
|
||||
|
||||
@@ -72,6 +72,7 @@ from .import_utils import (
|
||||
is_scipy_available,
|
||||
is_tensorboard_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
is_torchsde_available,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
"""
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import operator as op
|
||||
import os
|
||||
@@ -72,6 +73,15 @@ if _torch_xla_available:
|
||||
except ImportError:
|
||||
_torch_xla_available = False
|
||||
|
||||
# check whether torch_npu is available
|
||||
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
|
||||
if _torch_npu_available:
|
||||
try:
|
||||
_torch_npu_version = importlib_metadata.version("torch_npu")
|
||||
logger.info(f"torch_npu version {_torch_npu_version} available.")
|
||||
except ImportError:
|
||||
_torch_npu_available = False
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
@@ -294,6 +304,10 @@ def is_torch_xla_available():
|
||||
return _torch_xla_available
|
||||
|
||||
|
||||
def is_torch_npu_available():
|
||||
return _torch_npu_available
|
||||
|
||||
|
||||
def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
||||
Reference in New Issue
Block a user