add tokenizer and tests

This commit is contained in:
thomwolf 2019-06-21 11:09:51 +02:00
parent 45709d7532
commit 32da75486b
11 changed files with 511 additions and 57 deletions

View File

@ -3,6 +3,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
@ -24,4 +25,5 @@ from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path,
WEIGHTS_NAME, CONFIG_NAME)

View File

@ -1034,14 +1034,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention.
Set to None during finetuning.
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
@ -1068,4 +1060,4 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return output, new_mems
return logits, new_mems

View File

@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch XLNet model.
"""
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
def permutation_mask(inputs, targets, is_masked, perm_size, seq_len):
"""
Sample a permutation of the factorization order, and create an
attention mask accordingly.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected
for partial prediction.
perm_size: the length of longest permutation. Could be set to be reuse_len.
Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length.
"""
# Generate permutation indices
index = np.arange(10)
index = np.transpose(np.reshape(index, [-1, perm_size]))
index = np.random.shuffle(index)
index = np.reshape(np.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens
non_func_tokens = tf.logical_not(tf.logical_or(
tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
# Set the permutation indices of non-masked (& non-funcional) tokens to the
# smallest index (-1):
# (1) they can be seen by all other positions
# (2) they cannot see masked positions, so there won"t be information leak
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
# Create `target_mask`: non-funcional and maksed tokens
# 1: use mask as input and have loss
# 0: use token (or [SEP], [CLS]) as input and do not have loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# Create `perm_mask`
# `target_tokens` cannot see themselves
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked
perm_mask = tf.logical_and(
self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0: 1], targets[: -1]],
axis=0)
# construct inputs_k
inputs_k = inputs
# construct inputs_q
inputs_q = target_mask
return perm_mask, new_targets, target_mask, inputs_k, inputs_q

View File

@ -0,0 +1,254 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for XLNet model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os
import sys
from shutil import copyfile
from io import open
import unicodedata
import six
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
}
VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
SPIECE_UNDERLINE = ''
class XLNetTokenizer(object):
"""
SentencePiece based tokenizer. Peculiarities:
- requires SentencePiece: https://github.com/google/sentencepiece
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {}"
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, special_tokens=None, max_len=None,
do_lower_case=False, remove_space=True, keep_accents=False):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
self.max_len = max_len if max_len is not None else int(1e12)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens: %s", str(self.special_tokens))
def preprocess_text(self, inputs):
if self.remove_space:
outputs = ' '.join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace("``", '"').replace("''", '"')
if six.PY2 and isinstance(outputs, str):
outputs = outputs.decode('utf-8')
if not self.keep_accents:
outputs = unicodedata.normalize('NFKD', outputs)
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
if self.do_lower_case:
outputs = outputs.lower()
return outputs
def tokenize(self, text, return_unicode=True, sample=False):
""" Tokenize a string.
return_unicode is used only for py2
"""
text = self.preprocess_text(text)
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
if six.PY2 and isinstance(text, unicode):
text = text.encode('utf-8')
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
cur_pieces[0] = cur_pieces[0][1:]
cur_pieces.append(piece[-1])
new_pieces.extend(cur_pieces)
else:
new_pieces.append(piece)
# note(zhiliny): convert back to unicode for py2
if six.PY2 and return_unicode:
ret_pieces = []
for piece in new_pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
ret_pieces.append(piece)
new_pieces = ret_pieces
return new_pieces
def convert_tokens_to_ids(self, tokens, sample=False):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.sp_model.PieceToId(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.sp_model.PieceToId(token))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this XLNet model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in tokens."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.sp_model.IdToPiece(i))
return tokens
def encode(self, text, sample=False):
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens)
if clean_up_tokenization_spaces:
out_string = out_string.strip().replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
def save_vocabulary(self, vocab_path):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
copyfile(self.vocab_file, out_vocab_file)
index = len(self.sp_model)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return out_vocab_file, special_tokens_file

View File

@ -7,4 +7,6 @@ boto3
# Used for downloading models over HTTP
requests
# For OpenAI GPT
regex
regex
# For XLNet
sentencepiece

Binary file not shown.

View File

@ -54,7 +54,8 @@ setup(
'boto3',
'requests',
'tqdm',
'regex'],
'regex',
'sentencepiece'],
entry_points={
'console_scripts': [
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",

View File

@ -35,8 +35,8 @@ class XLNetModelTest(unittest.TestCase):
parent,
batch_size=13,
seq_length=7,
mem_len=30,
clamp_len=15,
mem_len=10,
clamp_len=-1,
reuse_len=15,
is_training=True,
use_labels=True,
@ -78,6 +78,27 @@ class XLNetModelTest(unittest.TestCase):
input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)
# inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
# seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
# input_mask: float32 Tensor in shape [len, bsz], the input mask.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
# from previous batches. The length of the list equals n_layer.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [len, len, bsz].
# If perm_mask[i, j, k] = 0, i attend to j in batch k;
# if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
# If None, each position attends to all the others.
# target_mapping: float32 Tensor in shape [num_predict, len, bsz].
# If target_mapping[i, j, k] = 1, the i-th predict in batch k is
# on the j-th token.
# Only used during pretraining for partial prediction.
# Set to None during finetuning.
# inp_q: float32 Tensor in shape [len, bsz].
# 1 for tokens with losses and 0 for tokens without losses.
# Only used during pretraining for two-stream attention.
# Set to None during finetuning.
lm_labels = None
if self.use_labels:
lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
@ -106,44 +127,15 @@ class XLNetModelTest(unittest.TestCase):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
model = XLNetLMHeadModel(config)
model.eval()
hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)
outputs = {
"hidden_states_1": hidden_states_1,
"mems_1": mems_1,
"hidden_states_2": hidden_states_2,
"mems_2": mems_2,
}
return outputs
def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states_1"].size()),
[self.seq_length, self.batch_size, self.d_model])
self.parent.assertListEqual(
list(result["hidden_states_2"].size()),
[self.seq_length, self.batch_size, self.d_model])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
model = XLNetLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, target=lm_labels)
lm_logits_1, mems_1b = model(input_ids_1)
loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
outputs = {
"loss_1": loss_1,
@ -160,23 +152,23 @@ class XLNetModelTest(unittest.TestCase):
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(
list(result["loss_1"].size()),
[self.seq_length, self.batch_size])
[])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
self.parent.assertListEqual(
list(result["loss_2"].size()),
[self.seq_length, self.batch_size])
[])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()),
[self.seq_length, self.batch_size, self.vocab_size])
@ -218,10 +210,6 @@ class XLNetModelTest(unittest.TestCase):
def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed()
output_result = tester.create_transfo_xl_model(*config_and_inputs)
tester.check_transfo_xl_model_output(output_result)
tester.set_seed()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result)
@ -242,6 +230,22 @@ class XLNetModelTest(unittest.TestCase):
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
@classmethod
def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a tensor with padding on the right (0.0 for )."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__":
unittest.main()

View File

@ -49,7 +49,7 @@ class TokenizationTest(unittest.TestCase):
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer.from_pretrained(vocab_file)
tokenizer = tokenizer.from_pretrained(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")

View File

@ -44,7 +44,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer.from_pretrained(vocab_file)
tokenizer = tokenizer.from_pretrained(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")

View File

@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP)
SAMPLE_VOCAB = os.path.join(os.path.dirname(
os.path.dirname(os.path.abspath(__file__))),
'samples/test_sentencepiece.model')
class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
tokens = tokenizer.tokenize('This is a test')
self.assertListEqual(tokens, ['▁This', '▁is', '▁a', '▁t', 'est'])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
vocab_path = "/tmp/"
vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
tokenizer = tokenizer.from_pretrained(vocab_path,
keep_accents=True)
os.remove(vocab_file)
os.remove(special_tokens_file)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '',
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
'▁is', '▁f', 'al', 's', 'é', '.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
602, 347, 347, 347, 3, 12, 66,
46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in',
'', '<unk>', '2', '0', '0', '0', ',',
'▁and', '▁this', '▁is', '▁f', 'al', 's',
'<unk>', '.'])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, ['', 'i', '▁was', '▁b', 'or', 'n', '▁in', '',
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
'▁is', '▁f', 'al', 'se', '.'])
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["▁he", "ll", "o"])
def test_tokenizer_no_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '',
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
'▁is', '▁f', 'al', 'se', '.'])
if __name__ == '__main__':
unittest.main()