Fix ROCm get_device_capability

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyy 2025-04-10 01:36:14 +08:00
parent ef38892ed4
commit 0e756434bd

View File

@ -3085,10 +3085,11 @@ def get_device_properties() -> DeviceProperties:
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
import torch import torch
major, _ = torch.cuda.get_device_capability()
if IS_ROCM_SYSTEM: if IS_ROCM_SYSTEM:
major, _ = torch.hip.get_device_capability()
return ("rocm", major) return ("rocm", major)
else: else:
major, _ = torch.cuda.get_device_capability()
return ("cuda", major) return ("cuda", major)
elif IS_XPU_SYSTEM: elif IS_XPU_SYSTEM:
import torch import torch