diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index ded1f7093bb..7be5031d0eb 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -3,6 +3,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_gpt2 import GPT2Tokenizer +from .tokenization_xlnet import XLNetTokenizer from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, @@ -24,4 +25,5 @@ from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig, from .optimization import BertAdam from .optimization_openai import OpenAIAdam -from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME +from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path, + WEIGHTS_NAME, CONFIG_NAME) diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py index 39a2d95a4fa..08b193acfdd 100644 --- a/pytorch_pretrained_bert/modeling_xlnet.py +++ b/pytorch_pretrained_bert/modeling_xlnet.py @@ -1034,14 +1034,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): Only used during pretraining for two-stream attention. Set to None during finetuning. - mem_len: int, the number of tokens to cache. - reuse_len: int, the number of tokens in the currect batch to be cached - and reused in the future. - bi_data: bool, whether to use bidirectional input pipeline. - Usually set to True during pretraining and False during finetuning. - clamp_len: int, clamp all relative distances larger than clamp_len. - -1 means no clamping. - same_length: bool, whether to use the same attention length for each token. summary_type: str, "last", "first", "mean", or "attn". The method to pool the input to get a vector representation. """ @@ -1068,4 +1060,4 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): # encoded_layers = encoded_layers[-1] # if self.output_attentions: # return all_attentions, encoded_layers, pooled_output - return output, new_mems + return logits, new_mems diff --git a/pytorch_pretrained_bert/modeling_xlnet_utilities.py b/pytorch_pretrained_bert/modeling_xlnet_utilities.py new file mode 100644 index 00000000000..e2611b7a41e --- /dev/null +++ b/pytorch_pretrained_bert/modeling_xlnet_utilities.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Utilities for PyTorch XLNet model. 
+""" + +from collections import defaultdict + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +special_symbols = { + "" : 0, + "" : 1, + "" : 2, + "" : 3, + "" : 4, + "" : 5, + "" : 6, + "" : 7, + "" : 8, +} + +VOCAB_SIZE = 32000 +UNK_ID = special_symbols[""] +CLS_ID = special_symbols[""] +SEP_ID = special_symbols[""] +MASK_ID = special_symbols[""] +EOD_ID = special_symbols[""] + + +def permutation_mask(inputs, targets, is_masked, perm_size, seq_len): + """ + Sample a permutation of the factorization order, and create an + attention mask accordingly. + Args: + inputs: int64 Tensor in shape [seq_len], input ids. + targets: int64 Tensor in shape [seq_len], target ids. + is_masked: bool Tensor in shape [seq_len]. True means being selected + for partial prediction. + perm_size: the length of longest permutation. Could be set to be reuse_len. + Should not be larger than reuse_len or there will be data leaks. + seq_len: int, sequence length. + """ + + # Generate permutation indices + index = np.arange(10) + index = np.transpose(np.reshape(index, [-1, perm_size])) + index = np.random.shuffle(index) + index = np.reshape(np.transpose(index), [-1]) + + # `perm_mask` and `target_mask` + # non-functional tokens + non_func_tokens = tf.logical_not(tf.logical_or( + tf.equal(inputs, SEP_ID), + tf.equal(inputs, CLS_ID))) + + non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens) + masked_or_func_tokens = tf.logical_not(non_mask_tokens) + + # Set the permutation indices of non-masked (& non-funcional) tokens to the + # smallest index (-1): + # (1) they can be seen by all other positions + # (2) they cannot see masked positions, so there won"t be information leak + smallest_index = -tf.ones([seq_len], dtype=tf.int64) + rev_index = tf.where(non_mask_tokens, smallest_index, index) + + # Create `target_mask`: non-funcional and maksed tokens + # 1: use mask as input and have loss + # 0: use token (or [SEP], [CLS]) as input and do not have loss + target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens) + target_mask = tf.cast(target_tokens, tf.float32) + + # Create `perm_mask` + # `target_tokens` cannot see themselves + self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1) + + # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens) + # 0: can attend if i > j or j is non-masked + perm_mask = tf.logical_and( + self_rev_index[:, None] <= rev_index[None, :], + masked_or_func_tokens) + perm_mask = tf.cast(perm_mask, tf.float32) + + # new target: [next token] for LM and [curr token] (self) for PLM + new_targets = tf.concat([inputs[0: 1], targets[: -1]], + axis=0) + + # construct inputs_k + inputs_k = inputs + + # construct inputs_q + inputs_q = target_mask + + return perm_mask, new_targets, target_mask, inputs_k, inputs_q + diff --git a/pytorch_pretrained_bert/tokenization_xlnet.py b/pytorch_pretrained_bert/tokenization_xlnet.py index e69de29bb2d..c9a3d40631d 100644 --- a/pytorch_pretrained_bert/tokenization_xlnet.py +++ b/pytorch_pretrained_bert/tokenization_xlnet.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization classes for XLNet model.""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import json +import logging +import os +import sys +from shutil import copyfile +from io import open + +import unicodedata +import six + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", +} +VOCAB_NAME = 'spiece.model' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + +SPIECE_UNDERLINE = '▁' + +class XLNetTokenizer(object): + """ + SentencePiece based tokenizer. Peculiarities: + - requires SentencePiece: https://github.com/google/sentencepiece + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + logger.error( + "Couldn't reach server at '{}' to download vocabulary.".format( + vocab_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {}" + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + # Instantiate tokenizer. 
+ if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs) + return tokenizer + + def __init__(self, vocab_file, special_tokens=None, max_len=None, + do_lower_case=False, remove_space=True, keep_accents=False): + try: + import sentencepiece as spm + except ImportError: + logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" + "pip install sentencepiece") + + self.max_len = max_len if max_len is not None else int(1e12) + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.keep_accents = keep_accents + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} + logger.info("Special tokens: %s", str(self.special_tokens)) + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = ' '.join(inputs.strip().split()) + else: + outputs = inputs + outputs = outputs.replace("``", '"').replace("''", '"') + + if six.PY2 and isinstance(outputs, str): + outputs = outputs.decode('utf-8') + + if not self.keep_accents: + outputs = unicodedata.normalize('NFKD', outputs) + outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def tokenize(self, text, return_unicode=True, sample=False): + """ Tokenize a string. + return_unicode is used only for py2 + """ + text = self.preprocess_text(text) + # note(zhiliny): in some systems, sentencepiece only accepts str for py2 + if six.PY2 and isinstance(text, unicode): + text = text.encode('utf-8') + + if not sample: + pieces = self.sp_model.EncodeAsPieces(text) + else: + pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) + new_pieces = [] + for piece in pieces: + if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): + cur_pieces = self.sp_model.EncodeAsPieces( + piece[:-1].replace(SPIECE_UNDERLINE, '')) + if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: + if len(cur_pieces[0]) == 1: + cur_pieces = cur_pieces[1:] + else: + cur_pieces[0] = cur_pieces[0][1:] + cur_pieces.append(piece[-1]) + new_pieces.extend(cur_pieces) + else: + new_pieces.append(piece) + + # note(zhiliny): convert back to unicode for py2 + if six.PY2 and return_unicode: + ret_pieces = [] + for piece in new_pieces: + if isinstance(piece, str): + piece = piece.decode('utf-8') + ret_pieces.append(piece) + new_pieces = ret_pieces + + return new_pieces + + def convert_tokens_to_ids(self, tokens, sample=False): + """ Converts a sequence of tokens into ids using the vocab. 
""" + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.sp_model.PieceToId(tokens) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.sp_model.PieceToId(token)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this XLNet model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in tokens.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.sp_model.IdToPiece(i)) + return tokens + + def encode(self, text, sample=False): + return self.convert_tokens_to_ids(self.tokenize(text, sample=sample)) + + def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + """Converts a sequence of ids in a string.""" + tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) + out_string = ''.join(tokens) + if clean_up_tokenization_spaces: + out_string = out_string.strip().replace('', '') + out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' + ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" + ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") + return out_string + + def save_vocabulary(self, vocab_path): + """ Save the sentencepiece vocabulary (copy original file) and special tokens file + to a directory. + """ + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + out_vocab_file = os.path.join(vocab_path, VOCAB_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + copyfile(self.vocab_file, out_vocab_file) + + index = len(self.sp_model) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 
+ " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return out_vocab_file, special_tokens_file diff --git a/requirements.txt b/requirements.txt index caf6471e86a..165fa74af9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,6 @@ boto3 # Used for downloading models over HTTP requests # For OpenAI GPT -regex \ No newline at end of file +regex +# For XLNet +sentencepiece \ No newline at end of file diff --git a/samples/test_sentencepiece.model b/samples/test_sentencepiece.model new file mode 100644 index 00000000000..376dda73010 Binary files /dev/null and b/samples/test_sentencepiece.model differ diff --git a/setup.py b/setup.py index fe7990447df..28e85a00688 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,8 @@ setup( 'boto3', 'requests', 'tqdm', - 'regex'], + 'regex', + 'sentencepiece'], entry_points={ 'console_scripts': [ "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main", diff --git a/tests/modeling_xlnet_test.py b/tests/modeling_xlnet_test.py index 30a6bfbec73..ec9cc2cda1d 100644 --- a/tests/modeling_xlnet_test.py +++ b/tests/modeling_xlnet_test.py @@ -35,8 +35,8 @@ class XLNetModelTest(unittest.TestCase): parent, batch_size=13, seq_length=7, - mem_len=30, - clamp_len=15, + mem_len=10, + clamp_len=-1, reuse_len=15, is_training=True, use_labels=True, @@ -78,6 +78,27 @@ class XLNetModelTest(unittest.TestCase): input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size) segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size) + # inp_k: int32 Tensor in shape [len, bsz], the input token IDs. + # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs. + # input_mask: float32 Tensor in shape [len, bsz], the input mask. + # 0 for real tokens and 1 for padding. + # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory + # from previous batches. The length of the list equals n_layer. + # If None, no memory is used. + # perm_mask: float32 Tensor in shape [len, len, bsz]. + # If perm_mask[i, j, k] = 0, i attend to j in batch k; + # if perm_mask[i, j, k] = 1, i does not attend to j in batch k. + # If None, each position attends to all the others. + # target_mapping: float32 Tensor in shape [num_predict, len, bsz]. + # If target_mapping[i, j, k] = 1, the i-th predict in batch k is + # on the j-th token. + # Only used during pretraining for partial prediction. + # Set to None during finetuning. + # inp_q: float32 Tensor in shape [len, bsz]. + # 1 for tokens with losses and 0 for tokens without losses. + # Only used during pretraining for two-stream attention. + # Set to None during finetuning. 
+ lm_labels = None if self.use_labels: lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size) @@ -106,44 +127,15 @@ class XLNetModelTest(unittest.TestCase): random.seed(self.seed) torch.manual_seed(self.seed) - def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels): - model = XLNetLMHeadModel(config) - model.eval() - - hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids) - hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1) - outputs = { - "hidden_states_1": hidden_states_1, - "mems_1": mems_1, - "hidden_states_2": hidden_states_2, - "mems_2": mems_2, - } - return outputs - - def check_transfo_xl_model_output(self, result): - self.parent.assertListEqual( - list(result["hidden_states_1"].size()), - [self.seq_length, self.batch_size, self.d_model]) - self.parent.assertListEqual( - list(result["hidden_states_2"].size()), - [self.seq_length, self.batch_size, self.d_model]) - self.parent.assertListEqual( - list(list(mem.size()) for mem in result["mems_1"]), - [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) - self.parent.assertListEqual( - list(list(mem.size()) for mem in result["mems_2"]), - [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) - - def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels): model = XLNetLMHeadModel(config) model.eval() - loss_1, mems_1a = model(input_ids_1, target=lm_labels) - lm_logits_1, mems_1b = model(input_ids_1) + loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels) + lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids) - loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a) - lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b) + loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a) + lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b) outputs = { "loss_1": loss_1, @@ -160,23 +152,23 @@ class XLNetModelTest(unittest.TestCase): def check_transfo_xl_lm_head_output(self, result): self.parent.assertListEqual( list(result["loss_1"].size()), - [self.seq_length, self.batch_size]) + []) self.parent.assertListEqual( list(result["lm_logits_1"].size()), [self.seq_length, self.batch_size, self.vocab_size]) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1a"]), - [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) + [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer) self.parent.assertListEqual( list(list(mem.size()) for mem in result["mems_1b"]), - [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) + [[self.seq_length, self.batch_size, self.d_model]] * self.n_layer) self.parent.assertListEqual( list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]), list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"])) self.parent.assertListEqual( list(result["loss_2"].size()), - [self.seq_length, self.batch_size]) + []) self.parent.assertListEqual( list(result["lm_logits_2"].size()), [self.seq_length, self.batch_size, self.vocab_size]) @@ -218,10 +210,6 @@ class XLNetModelTest(unittest.TestCase): def run_tester(self, tester): config_and_inputs = tester.prepare_config_and_inputs() - tester.set_seed() - output_result = tester.create_transfo_xl_model(*config_and_inputs) - tester.check_transfo_xl_model_output(output_result) - tester.set_seed() output_result = tester.create_transfo_xl_lm_head(*config_and_inputs) 
tester.check_transfo_xl_lm_head_output(output_result) @@ -242,6 +230,22 @@ class XLNetModelTest(unittest.TestCase): return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() + @classmethod + def mask_tensor(cls, shape, vocab_size, rng=None, name=None): + """Creates a tensor with padding on the right (0.0 for ).""" + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() + if __name__ == "__main__": unittest.main() diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index fe120a522c3..249f71f9842 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -49,7 +49,7 @@ class TokenizationTest(unittest.TestCase): tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") - tokenizer.from_pretrained(vocab_file) + tokenizer = tokenizer.from_pretrained(vocab_file) os.remove(vocab_file) tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") diff --git a/tests/tokenization_transfo_xl_test.py b/tests/tokenization_transfo_xl_test.py index bf0ac5db2fe..226db4598e5 100644 --- a/tests/tokenization_transfo_xl_test.py +++ b/tests/tokenization_transfo_xl_test.py @@ -44,7 +44,7 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") - tokenizer.from_pretrained(vocab_file) + tokenizer = tokenizer.from_pretrained(vocab_file) os.remove(vocab_file) tokens = tokenizer.tokenize(u" UNwanted , running") diff --git a/tests/tokenization_xlnet_test.py b/tests/tokenization_xlnet_test.py new file mode 100644 index 00000000000..e383dd78773 --- /dev/null +++ b/tests/tokenization_xlnet_test.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import os
+import unittest
+from io import open
+import shutil
+import pytest
+
+from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP)
+
+SAMPLE_VOCAB = os.path.join(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__))),
+    'samples/test_sentencepiece.model')
+
+class XLNetTokenizationTest(unittest.TestCase):
+
+    def test_full_tokenizer(self):
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
+
+        tokens = tokenizer.tokenize('This is a test')
+        self.assertListEqual(tokens, ['▁This', '▁is', '▁a', '▁t', 'est'])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
+
+        vocab_path = "/tmp/"
+        vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
+        tokenizer = tokenizer.from_pretrained(vocab_path,
+                                              keep_accents=True)
+        os.remove(vocab_file)
+        os.remove(special_tokens_file)
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '▁',
+                                      '9', '2', '0', '0', '0', ',', '▁and', '▁this',
+                                      '▁is', '▁f', 'al', 's', 'é', '.'])
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids, [8, 21, 84, 55, 24, 19, 7, 0,
+                  602, 347, 347, 347, 3, 12, 66,
+                  46, 72, 80, 6, 0, 4])
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(back_tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in',
+                                           '▁', '<unk>', '2', '0', '0', '0', ',',
+                                           '▁and', '▁this', '▁is', '▁f', 'al', 's',
+                                           '<unk>', '.'])
+
+    @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(tokenizer)
+
+    def test_tokenizer_lower(self):
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
+        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, ['▁', 'i', '▁was', '▁b', 'or', 'n', '▁in', '▁',
+                                      '9', '2', '0', '0', '0', ',', '▁and', '▁this',
+                                      '▁is', '▁f', 'al', 'se', '.'])
+        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["▁he", "ll", "o"])
+
+    def test_tokenizer_no_lower(self):
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
+        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '▁',
+                                      '9', '2', '0', '0', '0', ',', '▁and', '▁this',
+                                      '▁is', '▁f', 'al', 'se', '.'])
+
+
+if __name__ == '__main__':
+    unittest.main()
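
Reviewer note (not part of the patch): a minimal usage sketch of the new XLNetTokenizer added in tokenization_xlnet.py, assuming a repository checkout so that samples/test_sentencepiece.model (added in this PR) is available. The '<special1>' token is a made-up example, not something shipped with the model.

# Minimal XLNetTokenizer usage sketch (assumes a repo checkout).
from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer('samples/test_sentencepiece.model',
                           special_tokens=['<special1>'])  # '<special1>' is hypothetical

# SentencePiece pieces mark word boundaries with the '▁' prefix.
tokens = tokenizer.tokenize('This is a test')        # ['▁This', '▁is', '▁a', '▁t', 'est']
ids = tokenizer.convert_tokens_to_ids(tokens)

# Special tokens are indexed after the SentencePiece vocabulary,
# so the first one maps to len(tokenizer.sp_model).
special_id = tokenizer.convert_tokens_to_ids('<special1>')

# encode/decode round trip: decode joins the pieces back together and
# applies the light clean-up implemented in XLNetTokenizer.decode.
text = tokenizer.decode(tokenizer.encode('This is a test'))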
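
Reviewer note (not part of the patch): a toy call of permutation_mask from modeling_xlnet_utilities.py, assuming the torch-based version above. All tensor values are made up for illustration; seq_len must be a multiple of perm_size, and is_masked is expected to be a bool tensor.

# Toy permutation_mask call (illustrative values only).
import torch
from pytorch_pretrained_bert.modeling_xlnet_utilities import (
    permutation_mask, SEP_ID, CLS_ID)

seq_len, perm_size = 8, 4
inputs = torch.tensor([35, 64, 120, SEP_ID, 78, 91, SEP_ID, CLS_ID], dtype=torch.int64)
targets = torch.cat([inputs[1:], inputs[:1]])        # toy next-token targets
is_masked = torch.tensor([0, 1, 0, 0, 1, 0, 0, 0], dtype=torch.bool)

perm_mask, new_targets, target_mask, inputs_k, inputs_q = permutation_mask(
    inputs, targets, is_masked, perm_size, seq_len)

# perm_mask[i, j] == 1 means position i may not attend to position j.
assert perm_mask.shape == (seq_len, seq_len)
# target_mask flags the masked, non-functional tokens that carry a loss;
# here no functional token is masked, so it equals is_masked.
assert target_mask.tolist() == is_masked.float().tolist()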
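
Reviewer note (not part of the patch): a sketch of the two forward-call patterns exercised by the updated modeling test, reflecting the corrected `(logits, new_mems)` return of XLNetLMHeadModel. How the model/config is constructed is not shown in this diff, so the helper takes an already-built model; passing mems=None for the first segment is assumed to be equivalent to omitting it, as the forward docstring indicates.

# Sketch of the XLNetLMHeadModel call pattern used in tests/modeling_xlnet_test.py.
# Tensors follow the [seq_len, batch_size] layout of the test.
import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetLMHeadModel

def lm_step(model, input_ids, segment_ids, lm_labels=None, mems=None):
    """Run one segment through the LM head, optionally reusing cached mems."""
    model.eval()
    with torch.no_grad():
        if lm_labels is not None:
            # With a target, the model returns the loss and the new memory.
            loss, new_mems = model(input_ids, seg_id=segment_ids, target=lm_labels, mems=mems)
            return loss, new_mems
        # Without a target, the head returns (logits, new_mems)
        # (this PR fixes the return value from `output` to `logits`).
        logits, new_mems = model(input_ids, seg_id=segment_ids, mems=mems)
        return logits, new_mems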