1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/dml/hijack/utils.py
2024-02-05 22:43:10 +09:00

20 lines
504 B
Python

import torch
from typing import Callable
from modules.shared import log, opts
def catch_nan(func: Callable[[], torch.Tensor], *, max_retries: int = 10) -> torch.Tensor:
    """Call ``func`` and re-run it while its result contains NaN values.

    The retry logic is only active when ``opts.directml_catch_nan`` is
    enabled; otherwise ``func`` is called exactly once and its result is
    returned unchanged.

    Args:
        func: Zero-argument callable producing a tensor. It is re-invoked
            as-is on every retry ("same values"), so retrying only helps if
            the NaN is produced non-deterministically.
        max_retries: Maximum number of additional calls after the first one.
            Defaults to 10, the previously hard-coded limit.

    Returns:
        The first NaN-free tensor produced, or the last tensor produced if
        every attempt contained NaN (an error is logged in that case).
    """
    if not opts.directml_catch_nan:
        return func()

    tensor = func()
    # Counting NaNs with sum() mirrors the original check; kept instead of
    # any() since this module presumably targets the DirectML backend, whose
    # op coverage may differ — TODO confirm before switching.
    has_nan = tensor.isnan().sum() != 0
    tries = 0
    while has_nan and tries < max_retries:
        if tries == 0:
            # Warn only once per call, not once per retry.
            log.warning("NaN is produced. Retry with same values...")
        tries += 1
        tensor = func()
        # Recompute once per produced tensor; reused by the final check below.
        has_nan = tensor.isnan().sum() != 0

    if has_nan:
        log.error("Failed to cover NaN.")
    return tensor