diff --git a/utils/print_env.py b/utils/print_env.py index 443ed6eab6c..04ea99947e0 100644 --- a/utils/print_env.py +++ b/utils/print_env.py @@ -21,6 +21,7 @@ import os import sys import transformers +from transformers import is_torch_xpu_available os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -32,11 +33,21 @@ try: import torch print("Torch version:", torch.__version__) - print("Cuda available:", torch.cuda.is_available()) - print("Cuda version:", torch.version.cuda) - print("CuDNN version:", torch.backends.cudnn.version()) - print("Number of GPUs available:", torch.cuda.device_count()) - print("NCCL version:", torch.cuda.nccl.version()) + accelerator = "NA" + if torch.cuda.is_available(): + accelerator = "CUDA" + elif is_torch_xpu_available(): + accelerator = "XPU" + print("Torch accelerator:", accelerator) + + if accelerator == "CUDA": + print("Cuda version:", torch.version.cuda) + print("CuDNN version:", torch.backends.cudnn.version()) + print("Number of GPUs available:", torch.cuda.device_count()) + print("NCCL version:", torch.cuda.nccl.version()) + elif accelerator == "XPU": + print("SYCL version:", torch.version.xpu) + print("Number of XPUs available:", torch.xpu.device_count()) except ImportError: print("Torch version:", None)