fix xlnet tokenizer and python2

commit c946bb51a6
parent 181075635d
Author: thomwolf
Date:   2019-06-22 22:28:49 +02:00

2 changed files with 32 additions and 19 deletions


@@ -241,7 +241,7 @@ class XLNetTokenizer(object):
             )
         return ids
 
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+    def convert_ids_to_tokens(self, ids, return_unicode=True, skip_special_tokens=False):
         """Converts a sequence of ids in tokens."""
         tokens = []
         for i in ids:
@@ -250,6 +250,14 @@ class XLNetTokenizer(object):
                 tokens.append(self.special_tokens_decoder[i])
             else:
                 tokens.append(self.sp_model.IdToPiece(i))
+
+        if six.PY2 and return_unicode:
+            ret_pieces = []
+            for piece in tokens:
+                if isinstance(piece, str):
+                    piece = piece.decode('utf-8')
+                ret_pieces.append(piece)
+            tokens = ret_pieces
         return tokens
 
     def encode(self, text, sample=False):
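Context for the change above: under Python 2, SentencePiece's IdToPiece can hand back UTF-8 encoded `str` (byte) pieces, while callers expect `unicode` text; the new `return_unicode` flag also lets callers opt out of the decoding (e.g. tokenizer.convert_ids_to_tokens(ids, return_unicode=False)). A minimal standalone sketch of the same guard, assuming only the `six` compatibility library already used by the diff; the helper name is hypothetical:

    import six

    def ensure_unicode(pieces, return_unicode=True):
        # Python 2: `str` is a byte string, so decode UTF-8 bytes to `unicode`.
        # Python 3 (or return_unicode=False): pieces pass through unchanged.
        if six.PY2 and return_unicode:
            return [piece.decode('utf-8') if isinstance(piece, str) else piece
                    for piece in pieces]
        return pieces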


@@ -33,23 +33,24 @@ class XLNetTokenizationTest(unittest.TestCase):
 
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
 
-        tokens = tokenizer.tokenize('This is a test')
-        self.assertListEqual(tokens, ['▁This', '▁is', '▁a', '▁t', 'est'])
+        tokens = tokenizer.tokenize(u'This is a test')
+        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
 
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
 
-        vocab_path = "/tmp/"
+        vocab_path = u"/tmp/"
         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path,
                                               keep_accents=True)
         os.remove(vocab_file)
         os.remove(special_tokens_file)
 
-        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 's', 'é', '.'])
+        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
 
         ids = tokenizer.convert_tokens_to_ids(tokens)
         self.assertListEqual(
             ids, [8, 21, 84, 55, 24, 19, 7, 0,
@@ -57,10 +58,12 @@ class XLNetTokenizationTest(unittest.TestCase):
                   46, 72, 80, 6, 0, 4])
 
         back_tokens = tokenizer.convert_ids_to_tokens(ids)
-        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in',
-                                           SPIECE_UNDERLINE + '', '<unk>', '2', '0', '0', '0', ',',
-                                           SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this', SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 's',
-                                           '<unk>', '.'])
+        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
+                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
+                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
+                                           u'<unk>', u'.'])
 
     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
@@ -73,17 +76,19 @@ class XLNetTokenizationTest(unittest.TestCase):
 
     def test_tokenizer_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + '', 'i', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 'se', '.'])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["▁he", "ll", "o"])
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
+        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
 
     def test_tokenizer_no_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 'se', '.'])
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
+                                      u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
 
 if __name__ == '__main__':
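The u'' prefixes added throughout these tests matter only under Python 2, where an unprefixed literal in a UTF-8 source file is a byte string and no longer compares equal to the unicode pieces the tokenizer now returns; under Python 3 the prefix is a no-op. An illustrative sketch (hypothetical snippet, runs on both interpreters; Python 2 needs the coding declaration to accept the literal at all):

    # -*- coding: utf-8 -*-
    import six

    if six.PY2:
        # '▁This' is a UTF-8 byte string here; bytes never compare equal
        # to unicode in Python 2.
        assert '▁This' != u'▁This'
    else:
        # Every str literal is already unicode in Python 3.
        assert '▁This' == u'▁This'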