From 629a46aaa5efd59edd49b78539f104f3f40958eb Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 19:27:14 +0300 Subject: [PATCH] Disable IPEX attention if the GPU supports 64 bit --- modules/intel/ipex/__init__.py | 5 +++-- modules/intel/ipex/gradscaler.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/intel/ipex/__init__.py b/modules/intel/ipex/__init__.py index cbadd14fc..851bc79c4 100644 --- a/modules/intel/ipex/__init__.py +++ b/modules/intel/ipex/__init__.py @@ -166,8 +166,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() - attention_init() - ipex_diffusers() + if not torch.xpu.has_fp64_dtype(): + attention_init() + ipex_diffusers() except Exception as e: return False, e return True, None diff --git a/modules/intel/ipex/gradscaler.py b/modules/intel/ipex/gradscaler.py index 530212101..6eb56bc2b 100644 --- a/modules/intel/ipex/gradscaler.py +++ b/modules/intel/ipex/gradscaler.py @@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long +device_supports_fp64 = torch.xpu.has_fp64_dtype() OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state @@ -96,7 +97,10 @@ def unscale_(self, optimizer): # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) found_inf = torch.full( (1,), 0.0, dtype=torch.float32, device=self._scale.device )