mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Raise a TF-specific error when importing Torch classes (#18280)
* Raise a TF-specific error when importing Torch classes * Update src/transformers/utils/import_utils.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> * Add an inverse error for PyTorch users Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
5e0ffd9183
commit
a649de5551
@ -693,6 +693,30 @@ PYTORCH_IMPORT_ERROR = """
|
||||
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = """
|
||||
{0} requires the PyTorch library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to our PyTorch classes. This
|
||||
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
||||
If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use PyTorch please go to
|
||||
https://pytorch.org/get-started/locally/ and follow the instructions that
|
||||
match your environment.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
TF_IMPORT_ERROR_WITH_PYTORCH = """
|
||||
{0} requires the TensorFlow library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "TF", but are otherwise identically named to our TF classes.
|
||||
If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use TensorFlow, please follow the instructions on the
|
||||
installation page https://www.tensorflow.org/install that match your environment.
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
SKLEARN_IMPORT_ERROR = """
|
||||
@ -855,6 +879,15 @@ def requires_backends(obj, backends):
|
||||
backends = [backends]
|
||||
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
|
||||
# Raise an error for users who might not realize that classes without "TF" are torch-only
|
||||
if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
|
||||
raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
|
||||
# Raise the inverse error for PyTorch users trying to load TF classes
|
||||
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
|
||||
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
if failed:
|
||||
|
Loading…
Reference in New Issue
Block a user