From 24df44d9c731885d0f7d0ca0f3a74e70d7099e3b Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Tue, 7 Jan 2020 15:53:42 +0100
Subject: [PATCH] Black version python 3.5

---
 src/transformers/tokenization_xlm.py | 35 +++++++---------------------
 1 file changed, 8 insertions(+), 27 deletions(-)

diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py
index 85cde05f4ef..676321a164e 100644
--- a/src/transformers/tokenization_xlm.py
+++ b/src/transformers/tokenization_xlm.py
@@ -589,12 +589,8 @@ class XLMTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-        self.max_len_single_sentence = (
-            self.max_len - 2
-        )  # take into account special tokens
-        self.max_len_sentences_pair = (
-            self.max_len - 3
-        )  # take into account special tokens
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
 
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
@@ -778,9 +774,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                 else:
                     jieba = sys.modules["jieba"]
             except (AttributeError, ImportError):
-                logger.error(
-                    "Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps"
-                )
+                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                 logger.error("1. pip install jieba")
                 raise
             text = " ".join(jieba.cut(text))
@@ -829,9 +823,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_special_tokens_mask(
-        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
-    ):
+    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -853,12 +845,7 @@ class XLMTokenizer(PreTrainedTokenizer):
                     "You should not supply a second sequence if the provided sequence of "
                     "ids is already formated with special tokens for the model."
                 )
-            return list(
-                map(
-                    lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
-                    token_ids_0,
-                )
-            )
+            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0,))
 
         if token_ids_1 is not None:
             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
@@ -882,9 +869,7 @@ class XLMTokenizer(PreTrainedTokenizer):
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
-            logger.error(
-                "Vocabulary path ({}) should be a directory".format(save_directory)
-            )
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
         vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
         merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
@@ -894,15 +879,11 @@ class XLMTokenizer(PreTrainedTokenizer):
 
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
-            for bpe_tokens, token_index in sorted(
-                self.bpe_ranks.items(), key=lambda kv: kv[1]
-            ):
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                 if index != token_index:
                     logger.warning(
                         "Saving vocabulary to {}: BPE merge indices are not consecutive."
-                        " Please check that the tokenizer is not corrupted!".format(
-                            merge_file
-                        )
+                        " Please check that the tokenizer is not corrupted!".format(merge_file)
                     )
                     index = token_index
                 writer.write(" ".join(bpe_tokens) + "\n")