mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix ROCm get_device_capability
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
parent
ef38892ed4
commit
0e756434bd
@ -3085,10 +3085,11 @@ def get_device_properties() -> DeviceProperties:
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
import torch
|
||||
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if IS_ROCM_SYSTEM:
|
||||
major, _ = torch.hip.get_device_capability()
|
||||
return ("rocm", major)
|
||||
else:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return ("cuda", major)
|
||||
elif IS_XPU_SYSTEM:
|
||||
import torch
|
||||
|
Loading…
Reference in New Issue
Block a user