Merge pull request #2288 from aaugustin/better-handle-optional-imports

Improve handling of optional imports
This commit is contained in:
Aymeric Augustin 2019-12-23 22:28:47 +01:00 committed by GitHub
commit 072750f4dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 31 additions and 33 deletions

View File

@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand):
self._host = host
self._port = port
if not _serve_dependancies_installed:
raise ImportError(
raise RuntimeError(
"Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly."

View File

@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand
if not is_tf_available() and not is_torch_available():
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
# TF training parameters
USE_XLA = False

View File

@ -14,18 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
try:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
_has_sklearn = True
except (AttributeError, ImportError) as e:
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
_has_sklearn = False

View File

@ -324,7 +324,7 @@ def squad_convert_examples_to_features(
del new_features
if return_dataset == "pt":
if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
@ -354,7 +354,7 @@ def squad_convert_examples_to_features(
return features, dataset
elif return_dataset == "tf":
if not is_tf_available():
raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
def gen():
for ex in features:

View File

@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return features
elif return_tensors == "tf":
if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf
def gen():
@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return dataset
elif return_tensors == "pt":
if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch
from torch.utils.data import TensorDataset

View File

@ -76,12 +76,12 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
except ImportError as e:
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
raise
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info("Loading PyTorch weights from {}".format(pt_path))
@ -111,12 +111,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
import torch # noqa: F401
import tensorflow as tf # noqa: F401
from tensorflow.python.keras import backend as K
except ImportError as e:
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
raise
if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs
@ -209,12 +209,12 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
except ImportError as e:
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
raise
import transformers
@ -251,12 +251,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
except ImportError as e:
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
raise
new_pt_params_dict = {}
current_pt_params_dict = dict(pt_model.named_parameters())

View File

@ -454,12 +454,12 @@ class PreTrainedModel(nn.Module):
from transformers import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
except ImportError as e:
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
raise
else:
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []

View File

@ -68,7 +68,7 @@ def get_framework(model=None):
# Try to guess which framework to use from the model classname
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
elif not is_tf_available() and not is_torch_available():
raise ImportError(
raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."

View File

@ -100,6 +100,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.do_lower_case = do_lower_case
self.remove_space = remove_space
@ -127,6 +128,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)

View File

@ -107,6 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
"https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.vocab_file = vocab_file
self._extra_ids = extra_ids
@ -132,6 +133,7 @@ class T5Tokenizer(PreTrainedTokenizer):
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)

View File

@ -26,14 +26,12 @@ from collections import Counter, OrderedDict
import numpy as np
from .file_utils import cached_path
from .file_utils import cached_path, is_torch_available
from .tokenization_utils import PreTrainedTokenizer
try:
if is_torch_available():
import torch
except ImportError:
pass
logger = logging.getLogger(__name__)

View File

@ -646,7 +646,7 @@ class XLMTokenizer(PreTrainedTokenizer):
self.ja_word_tokenizer = Mykytea.Mykytea(
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
)
except (AttributeError, ImportError) as e:
except (AttributeError, ImportError):
logger.error(
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
)
@ -655,7 +655,7 @@ class XLMTokenizer(PreTrainedTokenizer):
logger.error("3. ./configure --prefix=$HOME/local")
logger.error("4. make && make install")
logger.error("5. pip install kytea")
raise e
raise
return list(self.ja_word_tokenizer.getWS(text))
@property
@ -760,12 +760,12 @@ class XLMTokenizer(PreTrainedTokenizer):
from pythainlp.tokenize import word_tokenize as th_word_tokenize
else:
th_word_tokenize = sys.modules["pythainlp"].word_tokenize
except (AttributeError, ImportError) as e:
except (AttributeError, ImportError):
logger.error(
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
)
logger.error("1. pip install pythainlp")
raise e
raise
text = th_word_tokenize(text)
elif lang == "zh":
try:
@ -773,10 +773,10 @@ class XLMTokenizer(PreTrainedTokenizer):
import jieba
else:
jieba = sys.modules["jieba"]
except (AttributeError, ImportError) as e:
except (AttributeError, ImportError):
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logger.error("1. pip install jieba")
raise e
raise
text = " ".join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang)
text = text.split()

View File

@ -100,6 +100,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.do_lower_case = do_lower_case
self.remove_space = remove_space
@ -127,6 +128,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
raise
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)