fix: Updated the is_torch_mps_available() function to include min_version argument (#32545)

* Fixed wrong argument in is_torch_mps_available() function call.

* Fixed wrong argument in is_torch_mps_available() function call.

* sorted the import.

* Fixed wrong argument in is_torch_mps_available() function call.

* Fixed wrong argument in is_torch_mps_available() function call.

* Update src/transformers/utils/import_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* removed extra space.

* Added type hint for the min_version parameter.

* Added missing import.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sai-Suraj-27 2024-08-13 01:12:57 +05:30 committed by GitHub
parent f1c8542ff7
commit 2a5a6ad18a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,7 +27,7 @@ from collections import OrderedDict
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
from typing import Any, Optional, Tuple, Union
from packaging import version
@ -420,12 +420,16 @@ def is_mambapy_available():
return False
def is_torch_mps_available():
def is_torch_mps_available(min_version: Optional[str] = None):
if is_torch_available():
import torch
if hasattr(torch.backends, "mps"):
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built()
if min_version is not None:
flag = version.parse(_torch_version) >= version.parse(min_version)
backend_available = backend_available and flag
return backend_available
return False