mirror of https://github.com/vladmandic/sdnext.git

Disable IPEX attention if the GPU supports 64 bit

Author: Disty0
Date: 2023-12-05 19:27:14 +03:00
parent 3a36c07c6c
commit 629a46aaa5
2 changed files with 8 additions and 3 deletions
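For context: intel_extension_for_pytorch (IPEX) registers the torch.xpu backend, and torch.xpu.has_fp64_dtype() reports whether the active GPU supports native FP64. The point of this commit is to apply the attention workarounds only on devices that lack FP64 (e.g. consumer Arc GPUs), since FP64-capable devices do not need them. A minimal probe of that capability, as a sketch assuming an IPEX install with an XPU device:

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401 -- registers the torch.xpu backend

    if torch.xpu.is_available():
        # has_fp64_dtype() is the capability check this commit branches on
        print("device:", torch.xpu.get_device_name(0))
        print("native FP64:", torch.xpu.has_fp64_dtype())
    else:
        print("no XPU device found")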


@@ -166,8 +166,9 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
         ipex_hijacks()
-        attention_init()
-        ipex_diffusers()
+        if not torch.xpu.has_fp64_dtype():
+            attention_init()
+            ipex_diffusers()
     except Exception as e:
         return False, e
     return True, None
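In effect, ipex_init() now installs the attention and diffusers hijacks only when the device lacks native FP64. The same gating pattern in isolation, as a sketch with hypothetical no-op stand-ins for the real attention_init() and ipex_diffusers():

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401 -- registers torch.xpu

    def attention_init():  # hypothetical stand-in for the real hijack
        print("attention hijacks installed")

    def ipex_diffusers():  # hypothetical stand-in for the real hijack
        print("diffusers hijacks installed")

    if not torch.xpu.has_fp64_dtype():
        # only devices without native FP64 get the workarounds
        attention_init()
        ipex_diffusers()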


@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
 OptState = ipex.cpu.autocast._grad_scaler.OptState
 _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
 _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
     # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
     assert self._scale is not None
-    inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
+    if device_supports_fp64:
+        inv_scale = self._scale.double().reciprocal().float()
+    else:
+        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
     found_inf = torch.full(
         (1,), 0.0, dtype=torch.float32, device=self._scale.device
     )
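The previous code always routed the scale tensor through the CPU, because the reciprocal is taken in FP64 and the XPU may not support that dtype; with the module-level device_supports_fp64 flag (computed once at import, see the first hunk of this file), the FP64-capable path now stays on the device and avoids two host round-trips per unscale_ call. A self-contained sketch of the two paths, using a hypothetical compute_inv_scale() helper (runs on CPU here; an XPU tensor behaves the same way):

    import torch

    def compute_inv_scale(scale: torch.Tensor, fp64_on_device: bool) -> torch.Tensor:
        # FP32 division can be imprecise, so the reciprocal is taken in FP64
        if fp64_on_device:
            # native FP64: compute on the tensor's own device, no copies
            return scale.double().reciprocal().float()
        # no native FP64: round-trip through the CPU for the FP64 math
        return scale.to("cpu").double().reciprocal().float().to(scale.device)

    scale = torch.tensor([65536.0])  # a typical loss-scale value
    print(compute_inv_scale(scale, fp64_on_device=True))   # tensor([1.5259e-05])
    print(compute_inv_scale(scale, fp64_on_device=False))  # tensor([1.5259e-05])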