diff --git a/utils/print_env.py b/utils/print_env.py index 0a1cfbef13..2d2acb59d5 100644 --- a/utils/print_env.py +++ b/utils/print_env.py @@ -34,13 +34,24 @@ try: print("Torch version:", torch.__version__) print("Cuda available:", torch.cuda.is_available()) - print("Cuda version:", torch.version.cuda) - print("CuDNN version:", torch.backends.cudnn.version()) - print("Number of GPUs available:", torch.cuda.device_count()) if torch.cuda.is_available(): + print("Cuda version:", torch.version.cuda) + print("CuDNN version:", torch.backends.cudnn.version()) + print("Number of GPUs available:", torch.cuda.device_count()) device_properties = torch.cuda.get_device_properties(0) total_memory = device_properties.total_memory / (1024**3) print(f"CUDA memory: {total_memory} GB") + + print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available()) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + print("XPU model:", torch.xpu.get_device_properties(0).name) + print("XPU compiler version:", torch.version.xpu) + print("Number of XPUs available:", torch.xpu.device_count()) + device_properties = torch.xpu.get_device_properties(0) + total_memory = device_properties.total_memory / (1024**3) + print(f"XPU memory: {total_memory} GB") + + except ImportError: print("Torch version:", None)