mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix: Updated the is_torch_mps_available()
function to include min_version
argument (#32545)
* Fixed wrong argument in is_torch_mps_available() function call. * Fixed wrong argument in is_torch_mps_available() function call. * sorted the import. * Fixed wrong argument in is_torch_mps_available() function call. * Fixed wrong argument in is_torch_mps_available() function call. * Update src/transformers/utils/import_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * removed extra space. * Added type hint for the min_version parameter. * Added missing import. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
f1c8542ff7
commit
2a5a6ad18a
@ -27,7 +27,7 @@ from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any, Tuple, Union
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -420,12 +420,16 @@ def is_mambapy_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_mps_available():
|
||||
def is_torch_mps_available(min_version: Optional[str] = None):
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if hasattr(torch.backends, "mps"):
|
||||
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
if min_version is not None:
|
||||
flag = version.parse(_torch_version) >= version.parse(min_version)
|
||||
backend_available = backend_available and flag
|
||||
return backend_available
|
||||
return False
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user