commit 73be60c47b (parent 6806f8204e)
Author: Lysandre Debut
Date:   2020-01-07 15:34:23 +01:00


@@ -474,7 +474,7 @@ def replace_unicode_punct(text):
     text = text.replace("！", "!")
     text = text.replace("（", "(")
     text = text.replace("；", ";")
-    text = text.replace("１", '1')
+    text = text.replace("１", "1")
     text = text.replace("」", '"')
     text = text.replace("「", '"')
     text = text.replace("０", "0")
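Aside (not part of the diff): replace_unicode_punct simply chains str.replace calls that map fullwidth punctuation to its ASCII counterpart. A standalone sketch of the idea, using only the mappings visible in this hunk (the demo function name is made up):

# Standalone sketch of the fullwidth-to-ASCII mapping above (illustrative only).
FULLWIDTH_TO_ASCII = {
    "！": "!",
    "（": "(",
    "；": ";",
    "１": "1",
    "」": '"',
    "「": '"',
    "０": "0",
}

def replace_unicode_punct_demo(text):
    # Apply each replacement in turn, mirroring the chain of str.replace calls.
    for src, dst in FULLWIDTH_TO_ASCII.items():
        text = text.replace(src, dst)
    return text

print(replace_unicode_punct_demo("「１０！」"))  # -> "10!"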
@@ -589,8 +589,12 @@ class XLMTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
-        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
-        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
+        self.max_len_single_sentence = (
+            self.max_len - 2
+        )  # take into account special tokens
+        self.max_len_sentences_pair = (
+            self.max_len - 3
+        )  # take into account special tokens

         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
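Aside (not part of the diff): the - 2 / - 3 offsets match the number of special tokens XLM inserts, as seen in build_inputs_with_special_tokens further down (cls + A + sep for a single sequence, cls + A + sep + B + sep for a pair). A toy sketch, with made-up token strings and an assumed 512-token limit:

# Toy illustration (not the tokenizer itself) of why 2 and 3 positions are reserved.
MAX_LEN = 512  # assumed model limit for the example

def add_special_tokens(tokens_a, tokens_b=None, cls="<cls>", sep="<sep>"):
    # Mirrors build_inputs_with_special_tokens: cls + A + sep (+ B + sep).
    if tokens_b is None:
        return [cls] + tokens_a + [sep]                     # 2 extra tokens
    return [cls] + tokens_a + [sep] + tokens_b + [sep]      # 3 extra tokens

print(len(add_special_tokens(["a", "b"])))    # 4 -> room for MAX_LEN - 2 real tokens
print(len(add_special_tokens(["a"], ["b"])))  # 5 -> room for MAX_LEN - 3 real tokens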
@@ -774,7 +778,9 @@ class XLMTokenizer(PreTrainedTokenizer):
                 else:
                     jieba = sys.modules["jieba"]
             except (AttributeError, ImportError):
-                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
+                logger.error(
+                    "Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps"
+                )
                 logger.error("1. pip install jieba")
                 raise
             text = " ".join(jieba.cut(text))
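Aside (not part of the diff): this block lazily imports jieba for Chinese input, reusing the module from sys.modules if it was already imported, and re-raises with an install hint otherwise. A minimal standalone sketch of the same pattern (the wrapper name is made up; jieba.cut is the real segmentation API):

import sys

def chinese_word_segment(text):
    # Lazily resolve jieba: import on first use, otherwise reuse sys.modules.
    try:
        if "jieba" not in sys.modules:
            import jieba
        else:
            jieba = sys.modules["jieba"]
    except (AttributeError, ImportError):
        print("Install jieba first: pip install jieba")
        raise
    # jieba.cut returns a generator of segments; join with spaces so the
    # downstream BPE step sees whitespace-separated "words".
    return " ".join(jieba.cut(text))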
@@ -823,7 +829,9 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep

-    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
+    def get_special_tokens_mask(
+        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
+    ):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -845,7 +853,12 @@ class XLMTokenizer(PreTrainedTokenizer):
                 "You should not supply a second sequence if the provided sequence of "
                 "ids is already formated with special tokens for the model."
             )
-            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+            return list(
+                map(
+                    lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
+                    token_ids_0,
+                )
+            )

         if token_ids_1 is not None:
             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
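Aside (not part of the diff): the mask built here is 1 at special-token positions and 0 elsewhere, so a pair of sequences of lengths 2 and 3 yields [1, 0, 0, 1, 0, 0, 0, 1]. A tiny sketch with a hypothetical helper:

def special_tokens_mask_demo(len_a, len_b=None):
    # Mirrors the shape produced above: 1 for cls/sep positions, 0 elsewhere.
    if len_b is None:
        return [1] + [0] * len_a + [1]
    return [1] + [0] * len_a + [1] + [0] * len_b + [1]

print(special_tokens_mask_demo(2))     # [1, 0, 0, 1]
print(special_tokens_mask_demo(2, 3))  # [1, 0, 0, 1, 0, 0, 0, 1]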
@@ -869,7 +882,9 @@ class XLMTokenizer(PreTrainedTokenizer):
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
-            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+            logger.error(
+                "Vocabulary path ({}) should be a directory".format(save_directory)
+            )
             return
         vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
         merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
@@ -879,11 +894,15 @@ class XLMTokenizer(PreTrainedTokenizer):
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
-            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+            for bpe_tokens, token_index in sorted(
+                self.bpe_ranks.items(), key=lambda kv: kv[1]
+            ):
                 if index != token_index:
                     logger.warning(
                         "Saving vocabulary to {}: BPE merge indices are not consecutive."
-                        " Please check that the tokenizer is not corrupted!".format(merge_file)
+                        " Please check that the tokenizer is not corrupted!".format(
+                            merge_file
+                        )
                     )
                     index = token_index
                 writer.write(" ".join(bpe_tokens) + "\n")
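Aside (not part of the diff): the loop writes the BPE merges ordered by rank, one space-separated pair per line, warning when the ranks are not consecutive. A small standalone sketch with a made-up bpe_ranks dict:

import logging

logging.basicConfig()
logger = logging.getLogger("merges-demo")

# Hypothetical ranks: merge pair -> priority (lower = applied earlier).
bpe_ranks = {("t", "h"): 0, ("th", "e</w>"): 1, ("a", "n"): 2}

def write_merges(path, ranks):
    index = 0
    with open(path, "w", encoding="utf-8") as writer:
        # Emit pairs in rank order, one "a b" pair per line, as in the diff above.
        for pair, rank in sorted(ranks.items(), key=lambda kv: kv[1]):
            if index != rank:
                logger.warning("Non-consecutive BPE merge index in %s", path)
                index = rank
            writer.write(" ".join(pair) + "\n")
            index += 1

write_merges("merges_demo.txt", bpe_ranks)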