Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Black version python 3.5

commit 24df44d9c7
parent 73be60c47b
@@ -589,12 +589,8 @@ class XLMTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-        self.max_len_single_sentence = (
-            self.max_len - 2
-        )  # take into account special tokens
-        self.max_len_sentences_pair = (
-            self.max_len - 3
-        )  # take into account special tokens
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
 
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
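Aside from the line-joining, the two attributes touched here encode how much room the special tokens take: an XLM single sequence is wrapped as <s> A </s> (2 special tokens) and a pair as <s> A </s> B </s> (3 special tokens). A minimal standalone sketch of that arithmetic, using a hypothetical max_len of 512 rather than a real model config:

max_len = 512  # hypothetical model limit, not read from a real checkpoint

# <s> ... </s> wraps a single sequence: 2 special tokens
max_len_single_sentence = max_len - 2   # 510 usable content tokens

# <s> A </s> B </s> wraps a sequence pair: 3 special tokens
max_len_sentences_pair = max_len - 3    # 509 usable content tokens

print(max_len_single_sentence, max_len_sentences_pair)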
@@ -778,9 +774,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                 else:
                     jieba = sys.modules["jieba"]
             except (AttributeError, ImportError):
-                logger.error(
-                    "Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps"
-                )
+                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                 logger.error("1. pip install jieba")
                 raise
             text = " ".join(jieba.cut(text))
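For context, the block above lazily imports jieba and logs an install hint before re-raising. A self-contained sketch of the same pattern outside the tokenizer class (zh_word_segment is a hypothetical helper name, not part of the library):

import logging
import sys

logger = logging.getLogger(__name__)

def zh_word_segment(text):
    # Reuse jieba if it was already imported, otherwise import it lazily.
    try:
        if "jieba" not in sys.modules:
            import jieba
        else:
            jieba = sys.modules["jieba"]
    except (AttributeError, ImportError):
        logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
        logger.error("1. pip install jieba")
        raise
    # jieba.cut returns a generator of segmented words; join them with spaces.
    return " ".join(jieba.cut(text))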
@@ -829,9 +823,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_special_tokens_mask(
-        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
-    ):
+    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
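The return line in the context above builds the XLM pair layout <s> A </s> B </s>. A worked example with made-up token ids (0 and 1 stand in for the real cls/sep ids):

cls = [0]  # hypothetical cls_token_id
sep = [1]  # hypothetical sep_token_id

token_ids_0 = [10, 11, 12]  # sequence A
token_ids_1 = [20, 21]      # sequence B

print(cls + token_ids_0 + sep)                      # [0, 10, 11, 12, 1]
print(cls + token_ids_0 + sep + token_ids_1 + sep)  # [0, 10, 11, 12, 1, 20, 21, 1]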
@@ -853,12 +845,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                     "You should not supply a second sequence if the provided sequence of "
                     "ids is already formated with special tokens for the model."
                 )
-            return list(
-                map(
-                    lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
-                    token_ids_0,
-                )
-            )
+            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0,))
 
         if token_ids_1 is not None:
             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
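The mask returned here marks special-token positions with 1 and sequence tokens with 0. A small worked example for the pair case in the last context line:

token_ids_0 = [10, 11, 12]
token_ids_1 = [20, 21]

mask = [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
print(mask)  # [1, 0, 0, 0, 1, 0, 0, 1] -> <s>, A, A, A, </s>, B, B, </s>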
@@ -882,9 +869,7 @@ class XLMTokenizer(PreTrainedTokenizer):
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
-            logger.error(
-                "Vocabulary path ({}) should be a directory".format(save_directory)
-            )
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
         vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
         merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
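The joined logger.error call sits inside a guard that bails out when save_directory is not a directory. A minimal sketch of that guard as a free function (the VOCAB_FILES_NAMES values below are assumed to match vocab.json / merges.txt from tokenization_xlm.py):

import logging
import os

logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}  # assumed values

def vocab_paths(save_directory):
    # Mirror of the directory guard above: refuse to save into a non-directory path.
    if not os.path.isdir(save_directory):
        logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
        return None
    vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
    merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
    return vocab_file, merge_file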
@@ -894,15 +879,11 @@ class XLMTokenizer(PreTrainedTokenizer):
 
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
-            for bpe_tokens, token_index in sorted(
-                self.bpe_ranks.items(), key=lambda kv: kv[1]
-            ):
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                 if index != token_index:
                     logger.warning(
                         "Saving vocabulary to {}: BPE merge indices are not consecutive."
-                        " Please check that the tokenizer is not corrupted!".format(
-                            merge_file
-                        )
+                        " Please check that the tokenizer is not corrupted!".format(merge_file)
                     )
                     index = token_index
                 writer.write(" ".join(bpe_tokens) + "\n")
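The loop above writes BPE merges ordered by rank and warns when the ranks are not consecutive. A standalone sketch of that write path with a toy bpe_ranks mapping (the trailing index += 1 bookkeeping lies just past the hunk and is assumed here):

import logging

logger = logging.getLogger(__name__)

def write_merges(merge_file, bpe_ranks):
    # bpe_ranks maps a merge pair like ("lo", "w") to its rank in the merges file.
    index = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        for bpe_tokens, token_index in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
            if index != token_index:
                logger.warning(
                    "Saving vocabulary to {}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!".format(merge_file)
                )
                index = token_index
            writer.write(" ".join(bpe_tokens) + "\n")
            index += 1  # assumed continuation, matching the tokenizer's bookkeeping

write_merges("merges.txt", {("l", "o"): 0, ("lo", "w"): 1})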