refine transformers env output (#38274)

* refine `transformers env` output

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix 2025-05-22 21:22:18 +08:00 committed by GitHub
parent 1234683309
commit dfbee79ca3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,6 +32,7 @@ from ..utils import (
is_torch_available,
is_torch_hpu_available,
is_torch_npu_available,
is_torch_xpu_available,
)
from . import BaseTransformersCLICommand
@ -89,15 +90,25 @@ class EnvironmentCommand(BaseTransformersCLICommand):
pt_version = "not installed"
pt_cuda_available = "NA"
pt_accelerator = "NA"
if is_torch_available():
import torch
pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available()
pt_xpu_available = torch.xpu.is_available()
pt_xpu_available = is_torch_xpu_available()
pt_npu_available = is_torch_npu_available()
pt_hpu_available = is_torch_hpu_available()
if pt_cuda_available:
pt_accelerator = "CUDA"
elif pt_xpu_available:
pt_accelerator = "XPU"
elif pt_npu_available:
pt_accelerator = "NPU"
elif pt_hpu_available:
pt_accelerator = "HPU"
tf_version = "not installed"
tf_cuda_available = "NA"
if is_tf_available():
@ -141,7 +152,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
"Accelerate version": f"{accelerate_version}",
"Accelerate config": f"{accelerate_config_str}",
"DeepSpeed version": f"{deepspeed_version}",
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
"PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})",
"Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
"Jax version": f"{jax_version}",