mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Merge pull request #2288 from aaugustin/better-handle-optional-imports
Improve handling of optional imports
This commit is contained in:
commit
072750f4dc
@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
if not _serve_dependancies_installed:
|
if not _serve_dependancies_installed:
|
||||||
raise ImportError(
|
raise RuntimeError(
|
||||||
"Using serve command requires FastAPI and unicorn. "
|
"Using serve command requires FastAPI and unicorn. "
|
||||||
"Please install transformers with [serving]: pip install transformers[serving]."
|
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||||
"Or install FastAPI and unicorn separatly."
|
"Or install FastAPI and unicorn separatly."
|
||||||
|
@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand
|
|||||||
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||||
|
|
||||||
# TF training parameters
|
# TF training parameters
|
||||||
USE_XLA = False
|
USE_XLA = False
|
||||||
|
@ -14,18 +14,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from scipy.stats import pearsonr, spearmanr
|
from scipy.stats import pearsonr, spearmanr
|
||||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||||
|
|
||||||
_has_sklearn = True
|
_has_sklearn = True
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError) as e:
|
||||||
logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
|
|
||||||
_has_sklearn = False
|
_has_sklearn = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,7 +324,7 @@ def squad_convert_examples_to_features(
|
|||||||
del new_features
|
del new_features
|
||||||
if return_dataset == "pt":
|
if return_dataset == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
|
||||||
|
|
||||||
# Convert to Tensors and build dataset
|
# Convert to Tensors and build dataset
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
@ -354,7 +354,7 @@ def squad_convert_examples_to_features(
|
|||||||
return features, dataset
|
return features, dataset
|
||||||
elif return_dataset == "tf":
|
elif return_dataset == "tf":
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
|
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
|
@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
return features
|
return features
|
||||||
elif return_tensors == "tf":
|
elif return_tensors == "tf":
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
|
|||||||
return dataset
|
return dataset
|
||||||
elif return_tensors == "pt":
|
elif return_tensors == "pt":
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
|
raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
|
@ -76,12 +76,12 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||||
logger.info("Loading PyTorch weights from {}".format(pt_path))
|
logger.info("Loading PyTorch weights from {}".format(pt_path))
|
||||||
@ -111,12 +111,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
if tf_inputs is None:
|
if tf_inputs is None:
|
||||||
tf_inputs = tf_model.dummy_inputs
|
tf_inputs = tf_model.dummy_inputs
|
||||||
@ -209,12 +209,12 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
@ -251,12 +251,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
new_pt_params_dict = {}
|
new_pt_params_dict = {}
|
||||||
current_pt_params_dict = dict(pt_model.named_parameters())
|
current_pt_params_dict = dict(pt_model.named_parameters())
|
||||||
|
@ -454,12 +454,12 @@ class PreTrainedModel(nn.Module):
|
|||||||
from transformers import load_tf2_checkpoint_in_pytorch_model
|
from transformers import load_tf2_checkpoint_in_pytorch_model
|
||||||
|
|
||||||
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
|
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise e
|
raise
|
||||||
else:
|
else:
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
|
@ -68,7 +68,7 @@ def get_framework(model=None):
|
|||||||
# Try to guess which framework to use from the model classname
|
# Try to guess which framework to use from the model classname
|
||||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||||
elif not is_tf_available() and not is_torch_available():
|
elif not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError(
|
raise RuntimeError(
|
||||||
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||||
|
@ -100,6 +100,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
|
|||||||
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
|
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
self.remove_space = remove_space
|
self.remove_space = remove_space
|
||||||
@ -127,6 +128,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
|
|||||||
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
|
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
self.sp_model = spm.SentencePieceProcessor()
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
|
@ -107,6 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
"https://github.com/google/sentencepiece"
|
"https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self._extra_ids = extra_ids
|
self._extra_ids = extra_ids
|
||||||
@ -132,6 +133,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
self.sp_model = spm.SentencePieceProcessor()
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
|
@ -26,14 +26,12 @@ from collections import Counter, OrderedDict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, is_torch_available
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
try:
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -646,7 +646,7 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.ja_word_tokenizer = Mykytea.Mykytea(
|
self.ja_word_tokenizer = Mykytea.Mykytea(
|
||||||
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
|
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
|
||||||
)
|
)
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
|
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
|
||||||
)
|
)
|
||||||
@ -655,7 +655,7 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
logger.error("3. ./configure --prefix=$HOME/local")
|
logger.error("3. ./configure --prefix=$HOME/local")
|
||||||
logger.error("4. make && make install")
|
logger.error("4. make && make install")
|
||||||
logger.error("5. pip install kytea")
|
logger.error("5. pip install kytea")
|
||||||
raise e
|
raise
|
||||||
return list(self.ja_word_tokenizer.getWS(text))
|
return list(self.ja_word_tokenizer.getWS(text))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -760,12 +760,12 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
from pythainlp.tokenize import word_tokenize as th_word_tokenize
|
from pythainlp.tokenize import word_tokenize as th_word_tokenize
|
||||||
else:
|
else:
|
||||||
th_word_tokenize = sys.modules["pythainlp"].word_tokenize
|
th_word_tokenize = sys.modules["pythainlp"].word_tokenize
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
|
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
|
||||||
)
|
)
|
||||||
logger.error("1. pip install pythainlp")
|
logger.error("1. pip install pythainlp")
|
||||||
raise e
|
raise
|
||||||
text = th_word_tokenize(text)
|
text = th_word_tokenize(text)
|
||||||
elif lang == "zh":
|
elif lang == "zh":
|
||||||
try:
|
try:
|
||||||
@ -773,10 +773,10 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
import jieba
|
import jieba
|
||||||
else:
|
else:
|
||||||
jieba = sys.modules["jieba"]
|
jieba = sys.modules["jieba"]
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError):
|
||||||
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
|
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
|
||||||
logger.error("1. pip install jieba")
|
logger.error("1. pip install jieba")
|
||||||
raise e
|
raise
|
||||||
text = " ".join(jieba.cut(text))
|
text = " ".join(jieba.cut(text))
|
||||||
text = self.moses_pipeline(text, lang=lang)
|
text = self.moses_pipeline(text, lang=lang)
|
||||||
text = text.split()
|
text = text.split()
|
||||||
|
@ -100,6 +100,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
self.remove_space = remove_space
|
self.remove_space = remove_space
|
||||||
@ -127,6 +128,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece"
|
"pip install sentencepiece"
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
self.sp_model = spm.SentencePieceProcessor()
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
self.sp_model.Load(self.vocab_file)
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user