Make the sacremoses dependency optional (#17049)

* Make sacremoses optional

* Pickle
Lysandre Debut 2022-05-02 13:47:47 -03:00 committed by GitHub
parent bb2e088be7
commit 30ca529902
4 changed files with 63 additions and 12 deletions

setup.py

@@ -288,6 +288,7 @@ extras["testing"] = (
         "nltk",
         "GitPython",
         "hf-doc-builder",
+        "sacremoses",
     )
     + extras["retrieval"]
     + extras["modelcreation"]
@@ -365,7 +366,6 @@ extras["torchhub"] = deps_list(
     "protobuf",
     "regex",
     "requests",
-    "sacremoses",
     "sentencepiece",
     "torch",
     "tokenizers",
@@ -383,7 +383,6 @@ install_requires = [
     deps["pyyaml"],  # used for the model cards metadata
     deps["regex"],  # for OpenAI GPT
     deps["requests"],  # for downloading models over HTTPS
-    deps["sacremoses"],  # for XLM
     deps["tokenizers"],
     deps["tqdm"],  # progress bars in model download and training scripts
 ]
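
These setup.py changes move sacremoses from a hard requirement to an extra: a plain pip install of the package no longer pulls it in, but the testing extra still does. A minimal sketch of the setuptools mechanism this relies on (package and dependency names below are illustrative, not the actual transformers setup.py):

from setuptools import setup

setup(
    name="mypackage",
    install_requires=["requests"],  # installed for every user
    extras_require={
        # pulled in only by `pip install "mypackage[testing]"`
        "testing": ["pytest", "sacremoses"],
    },
)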

src/transformers/dependency_versions_check.py

@@ -23,7 +23,7 @@ from .utils.versions import require_version, require_version_core
 # order specific notes:
 # - tqdm must be checked before tokenizers
-pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
 if sys.version_info < (3, 7):
     pkgs_to_check_at_runtime.append("dataclasses")
 if sys.version_info < (3, 8):
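
Removing sacremoses from pkgs_to_check_at_runtime means a bare import of transformers no longer fails when the package is absent; only the tokenizers that actually need it will raise. A simplified sketch of what such a runtime check does (the real logic lives in transformers.utils.versions and also verifies versions; this standalone version assumes only Python 3.8+ for importlib.metadata):

import importlib.metadata

def require_installed(pkg: str) -> None:
    # Fail fast at import time if a core dependency is missing.
    try:
        importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        raise ImportError(f"Core dependency '{pkg}' is not installed.")

for pkg in "tqdm regex requests packaging filelock numpy tokenizers".split():
    require_installed(pkg)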

src/transformers/models/fsmt/tokenization_fsmt.py

@@ -21,8 +21,6 @@ import re
 import unicodedata
 from typing import Dict, List, Optional, Tuple
 
-import sacremoses as sm
-
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
 
@@ -212,6 +210,16 @@ class FSMTTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
         self.src_vocab_file = src_vocab_file
         self.tgt_vocab_file = tgt_vocab_file
         self.merges_file = merges_file
@@ -254,13 +262,13 @@ class FSMTTokenizer(PreTrainedTokenizer):
 
     def moses_punct_norm(self, text, lang):
         if lang not in self.cache_moses_punct_normalizer:
-            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
             self.cache_moses_punct_normalizer[lang] = punct_normalizer
         return self.cache_moses_punct_normalizer[lang].normalize(text)
 
     def moses_tokenize(self, text, lang):
         if lang not in self.cache_moses_tokenizer:
-            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
             self.cache_moses_tokenizer[lang] = moses_tokenizer
         return self.cache_moses_tokenizer[lang].tokenize(
             text, aggressive_dash_splits=True, return_str=False, escape=True
@@ -268,7 +276,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
 
     def moses_detokenize(self, tokens, lang):
         if lang not in self.cache_moses_tokenizer:
-            moses_detokenizer = sm.MosesDetokenizer(lang=self.tgt_lang)
+            moses_detokenizer = self.sm.MosesDetokenizer(lang=self.tgt_lang)
             self.cache_moses_detokenizer[lang] = moses_detokenizer
         return self.cache_moses_detokenizer[lang].detokenize(tokens)
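
The sm. → self.sm. rewrites above all follow the same pattern: the module handle now lives on the instance, and one Moses helper is built per language and then reused. A standalone sketch of that caching pattern (class and attribute names are illustrative):

class CachedDetokenizer:
    def __init__(self):
        import sacremoses  # deferred so the dependency stays optional

        self.sm = sacremoses
        self.cache = {}

    def detokenize(self, tokens, lang):
        # Instantiate one MosesDetokenizer per language, lazily.
        if lang not in self.cache:
            self.cache[lang] = self.sm.MosesDetokenizer(lang=lang)
        return self.cache[lang].detokenize(tokens)
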
@@ -516,3 +524,21 @@ class FSMTTokenizer(PreTrainedTokenizer):
             index += 1
 
         return src_vocab_file, tgt_vocab_file, merges_file
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
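
Keeping the imported module on the instance is what makes the new __getstate__/__setstate__ pair above necessary: module objects cannot be pickled, so the handle is dropped on serialization and re-imported on load. A self-contained sketch of the full pattern, assuming only that sacremoses is installed:

import pickle

class MosesHolder:
    def __init__(self):
        try:
            import sacremoses
        except ImportError:
            raise ImportError("Run `pip install sacremoses` to use this class.")
        self.sm = sacremoses

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sm"] = None  # module objects are not picklable
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        import sacremoses  # safe: the object existed, so the import worked before

        self.sm = sacremoses

holder = pickle.loads(pickle.dumps(MosesHolder()))  # round-trips cleanly
assert holder.sm is not None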

src/transformers/models/xlm/tokenization_xlm.py

@@ -22,8 +22,6 @@ import sys
 import unicodedata
 from typing import List, Optional, Tuple
 
-import sacremoses as sm
-
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
 
@@ -629,6 +627,16 @@ class XLMTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
         # cache of sm.MosesTokenizer instance
@@ -659,7 +667,7 @@ class XLMTokenizer(PreTrainedTokenizer):
 
     def moses_punct_norm(self, text, lang):
         if lang not in self.cache_moses_punct_normalizer:
-            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
             self.cache_moses_punct_normalizer[lang] = punct_normalizer
         else:
             punct_normalizer = self.cache_moses_punct_normalizer[lang]
@@ -667,7 +675,7 @@ class XLMTokenizer(PreTrainedTokenizer):
 
    def moses_tokenize(self, text, lang):
         if lang not in self.cache_moses_tokenizer:
-            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
             self.cache_moses_tokenizer[lang] = moses_tokenizer
         else:
             moses_tokenizer = self.cache_moses_tokenizer[lang]
@@ -970,3 +978,21 @@ class XLMTokenizer(PreTrainedTokenizer):
             index += 1
 
         return vocab_file, merge_file
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
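
With both tokenizers converted, the failure mode moves from import time to construction time. A hypothetical usage example (xlm-mlm-en-2048 is a real XLM checkpoint, but any would behave the same):

import pickle

from transformers import XLMTokenizer

# Succeeds only if sacremoses is installed; otherwise raises the new
# ImportError with installation instructions.
tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")

# __getstate__ drops the module handle; __setstate__ re-imports it.
restored = pickle.loads(pickle.dumps(tokenizer))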