1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/dml/device.py
2024-05-29 12:59:16 +09:00

18 lines
496 B
Python

from typing import Optional
import torch
from .utils import rDevice, get_device
class Device:
    """Context manager that makes a device the ``torch.dml`` context device.

    The device handed to the constructor is remembered and activated when the
    ``with`` block is entered; the context device that was active before is
    put back when the block exits, so nested blocks do not clobber each other.

    Usage::

        with Device(device):
            ...  # torch.dml.context_device is get_device(device) here

    Devices are resolved through ``get_device``; what ``None`` resolves to is
    decided by that helper (presumably the backend's current/default device --
    confirm in ``.utils``).
    """

    idx: int  # index of the device most recently resolved by this instance
    _device: Optional[rDevice] = None  # device requested at construction time
    _previous: Optional[torch.device] = None  # context device to restore on exit

    def __init__(self, device: Optional[rDevice] = None) -> None:
        """Remember the requested device and record its index in ``idx``.

        Args:
            device: Device to activate on ``__enter__``; ``None`` defers the
                choice to ``get_device``.
        """
        # Keep the raw request (not the resolved device) so that a ``None``
        # request is still resolved at entry time, exactly as before.
        self._device = device
        self.idx = get_device(device).index

    def __enter__(self, device: Optional[rDevice] = None) -> "Device":
        """Activate the device and return this instance.

        Args:
            device: Optional override. The ``with`` statement never passes
                it, so it only matters for direct calls; when omitted, the
                device given to the constructor is used.

        Returns:
            This instance, so ``with Device(...) as d`` binds something useful.
        """
        requested = self._device if device is None else device
        # Save whatever was active so __exit__ can restore it; the attribute
        # may not exist yet if no context device was ever set.
        self._previous = getattr(torch.dml, "context_device", None)
        torch.dml.context_device = get_device(requested)
        self.idx = torch.dml.context_device.index
        return self

    def __exit__(self, t, v, tb):
        """Restore the context device that was active before ``__enter__``.

        Returns ``None``, so an exception raised inside the block propagates.
        """
        torch.dml.context_device = self._previous
        self._previous = None