mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Improve exception type.
ImportError isn't really appropriate when there's no import involved.
This commit is contained in:
parent
4c09a96096
commit
c8b0c1e551
@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
if not _serve_dependancies_installed:
|
if not _serve_dependancies_installed:
|
||||||
raise ImportError(
|
raise RuntimeError(
|
||||||
"Using serve command requires FastAPI and unicorn. "
|
"Using serve command requires FastAPI and unicorn. "
|
||||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||||
"Or install FastAPI and unicorn separatly."
|
"Or install FastAPI and unicorn separatly."
|
||||||
|
@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand
|
|||||||
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||||
|
|
||||||
# TF training parameters
|
# TF training parameters
|
||||||
USE_XLA = False
|
USE_XLA = False
|
||||||
|
@ -324,7 +324,7 @@ def squad_convert_examples_to_features(
|
|||||||
del new_features
|
del new_features
|
||||||
if return_dataset == "pt":
|
if return_dataset == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
|
||||||
|
|
||||||
# Convert to Tensors and build dataset
|
# Convert to Tensors and build dataset
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
@ -354,7 +354,7 @@ def squad_convert_examples_to_features(
|
|||||||
return features, dataset
|
return features, dataset
|
||||||
elif return_dataset == "tf":
|
elif return_dataset == "tf":
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
|
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
|
@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
return features
|
return features
|
||||||
elif return_tensors == "tf":
|
elif return_tensors == "tf":
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
return dataset
|
return dataset
|
||||||
elif return_tensors == "pt":
|
elif return_tensors == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
|
raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ def get_framework(model=None):
|
|||||||
# Try to guess which framework to use from the model classname
|
# Try to guess which framework to use from the model classname
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||||
elif not is_tf_available() and not is_torch_available():
|
elif not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError(
|
raise RuntimeError(
|
||||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||||
|
Loading…
Reference in New Issue
Block a user