mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Fix ROCm get_device_capability
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
parent
ef38892ed4
commit
0e756434bd
@ -3085,10 +3085,11 @@ def get_device_properties() -> DeviceProperties:
|
|||||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
major, _ = torch.cuda.get_device_capability()
|
|
||||||
if IS_ROCM_SYSTEM:
|
if IS_ROCM_SYSTEM:
|
||||||
|
major, _ = torch.hip.get_device_capability()
|
||||||
return ("rocm", major)
|
return ("rocm", major)
|
||||||
else:
|
else:
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
return ("cuda", major)
|
return ("cuda", major)
|
||||||
elif IS_XPU_SYSTEM:
|
elif IS_XPU_SYSTEM:
|
||||||
import torch
|
import torch
|
||||||
|
Loading…
Reference in New Issue
Block a user