mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix logic order for USE_TF/USE_TORCH
This commit is contained in:
parent
5664327c24
commit
faef6f6191
@ -29,25 +29,27 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
try:
|
||||
os.environ.setdefault('USE_TF', 'YES')
|
||||
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
|
||||
logger.info("USE_TF override through env variable, disabling Tensorflow")
|
||||
_tf_available = False
|
||||
else:
|
||||
import tensorflow as tf
|
||||
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
else:
|
||||
logger.info("USE_TF override through env variable, disabling Tensorflow")
|
||||
_tf_available = False
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
os.environ.setdefault('USE_TORCH', 'YES')
|
||||
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
|
||||
logger.info("USE_TORCH override through env variable, disabling PyTorch")
|
||||
_torch_available = False
|
||||
else:
|
||||
import torch
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
|
||||
else:
|
||||
logger.info("USE_TORCH override through env variable, disabling PyTorch")
|
||||
_torch_available = False
|
||||
except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user