Mirror of https://github.com/huggingface/transformers.git

Commit e75c3f70aa (parent c0239e09e6): standardizing tokenizers API and adding tests
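
Note: across the diff below, every tokenizer converges on the same surface: encode(text) (tokenize + convert_tokens_to_ids), decode(ids, clean_up_tokenization_spaces=True), save_vocabulary(vocab_path) returning a tuple of written file paths, and from_pretrained(...) for reloading. A minimal sketch of the round-trip the new common tests exercise; the vocab and output paths are illustrative and not taken from the commit:

# Hedged sketch of the standardized tokenizer API this commit targets.
# Paths are illustrative; the other tokenizer classes touched below follow the same pattern.
from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer("./vocab.txt")
ids = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")   # tokenize + convert_tokens_to_ids

output_files = tokenizer.save_vocabulary(vocab_path="./saved/")      # now returns a tuple of written file paths
reloaded = BertTokenizer.from_pretrained("./saved/")                 # reload from the saved directory

assert reloaded.encode(u"He is very happy, UNwant\u00E9d,running") == ids
print(tokenizer.decode(ids))                                          # re-joins word pieces, then clean_up_tokenization()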
@@ -598,3 +598,9 @@ def prune_layer(layer, index, dim=None):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))


def clean_up_tokenization(out_string):
    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
            ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
            ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
    return out_string
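
Note: the helper only collapses the whitespace that detokenization leaves around punctuation and contractions. A quick illustration; the input string is made up for this note, and the import path assumes the helper lives in model_utils as the imports further down indicate:

from pytorch_pretrained_bert.model_utils import clean_up_tokenization

# " ?" -> "?", " ," -> ",", " n't" -> "n't", " 's" -> "'s"
print(clean_up_tokenization("he 's a runner , is n't he ?"))
# -> he's a runner, isn't he?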
@@ -48,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
@@ -26,6 +26,7 @@ from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer,
                                                        _is_control, _is_punctuation,
                                                        _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)

from .tokenization_tests_commons import create_and_check_tokenizer_commons

class TokenizationTest(unittest.TestCase):

@@ -36,28 +37,18 @@ class TokenizationTest(unittest.TestCase):
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            vocab_file = vocab_writer.name

        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)

        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
@@ -22,6 +22,7 @@ import pytest

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

from .tokenization_tests_commons import create_and_check_tokenizer_commons

class GPT2TokenizationTest(unittest.TestCase):

@@ -39,10 +40,9 @@ class GPT2TokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)
        create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)

@@ -53,17 +53,8 @@ class GPT2TokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    # @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
@@ -22,6 +22,8 @@ import pytest

from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

from .tokenization_tests_commons import create_and_check_tokenizer_commons


class OpenAIGPTTokenizationTest(unittest.TestCase):

@@ -40,6 +42,8 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)

@@ -54,18 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
pytorch_pretrained_bert/tests/tokenization_tests_commons.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from io import open

if sys.version_info[0] == 3:
    unicode = str

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

    vocab_path = "/tmp/"
    output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
    tokenizer = tokenizer.from_pretrained(vocab_path)

    for f in output_files:
        os.remove(f)

    after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
    tester.assertListEqual(before_tokens, after_tokens)


def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    text = "Munich and Berlin are nice cities"
    filename = u"/tmp/tokenizer.bin"

    subwords = tokenizer.tokenize(text)

    pickle.dump(tokenizer, open(filename, "wb"))

    tokenizer_new = pickle.load(open(filename, "rb"))
    subwords_loaded = tokenizer_new.tokenize(text)

    tester.assertListEqual(subwords, subwords_loaded)


def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class(*inputs, **kwargs)

    text = u"He is very happy, UNwant\u00E9d,running"
    tokens = tokenizer.tokenize(text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_2 = tokenizer.encode(text)
    tester.assertListEqual(ids, ids_2)

    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
    text_2 = tokenizer.decode(ids)

    tester.assertNotEqual(len(tokens_2), 0)
    tester.assertIsInstance(text_2, (str, unicode))


def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
    create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
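
Note: every model-specific test file in this commit delegates to create_and_check_tokenizer_commons, so wiring in a new tokenizer amounts to building its fixture files and making one call. A hypothetical test module in the same style, placed alongside the others in pytorch_pretrained_bert/tests/ (the class name and vocab contents are illustrative, mirroring the BERT test above):

import unittest
from io import open

from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from .tokenization_tests_commons import create_and_check_tokenizer_commons


class ToyBertTokenizationTest(unittest.TestCase):

    def test_commons(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa",
                        "un", "runn", "##ing", ","]
        with open("/tmp/toy_bert_vocab.txt", "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        # Runs the required-methods, save/load and pickle checks in one call.
        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)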
@@ -22,6 +22,7 @@ import pytest

from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

from .tokenization_tests_commons import create_and_check_tokenizer_commons

class TransfoXLTokenizationTest(unittest.TestCase):

@@ -33,8 +34,9 @@ class TransfoXLTokenizationTest(unittest.TestCase):
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)

        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
        tokenizer.build_vocab()
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")

@@ -43,17 +45,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])


    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)
@@ -22,6 +22,7 @@ import pytest

from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

from .tokenization_tests_commons import create_and_check_tokenizer_commons

class XLMTokenizationTest(unittest.TestCase):

@@ -40,6 +41,8 @@ class XLMTokenizationTest(unittest.TestCase):
            fp.write("\n".join(merges))
            merges_file = fp.name

        create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])

        tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
        os.remove(vocab_file)
        os.remove(merges_file)

@@ -54,18 +57,6 @@ class XLMTokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer_2 = XLMTokenizer.from_pretrained("/tmp/")
        os.remove(vocab_file)
        os.remove(merges_file)
        os.remove(special_tokens_file)

        self.assertListEqual(
            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
@@ -15,28 +15,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import unittest
from io import open
import shutil
import pytest

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
                                                         PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                         SPIECE_UNDERLINE)

from .tokenization_tests_commons import create_and_check_tokenizer_commons

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')

class XLNetTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
        create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)

        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

@@ -44,11 +41,6 @@ class XLNetTokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        vocab_path = u"/tmp/"
        vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
        tokenizer = tokenizer.from_pretrained(vocab_path,
                                              keep_accents=True)

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',

@@ -68,22 +60,6 @@ class XLNetTokenizationTest(unittest.TestCase):
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                      u'<unk>', u'.'])

        text = "Munich and Berlin are nice cities"
        filename = u"/tmp/tokenizer.bin"

        subwords = tokenizer.tokenize(text)

        pickle.dump(tokenizer, open(filename, "wb"))

        tokenizer_new = pickle.load(open(filename, "rb"))
        subwords_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(subwords, subwords_loaded)

        os.remove(filename)
        os.remove(vocab_file)
        os.remove(special_tokens_file)

    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
@@ -23,6 +23,7 @@ import unicodedata
from io import open

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

logger = logging.getLogger(__name__)

@@ -185,6 +186,19 @@ class BertTokenizer(object):
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, token_ids, clean_up_tokenization_spaces=True):
        """Converts a sequence of ids into a string."""
        tokens = self.convert_ids_to_tokens(token_ids)
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        if clean_up_tokenization_spaces:
            for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
                out_string = out_string.replace(special_tok, '')
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0

@@ -198,7 +212,7 @@ class BertTokenizer(object):
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
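
Note: with these additions, encoding and decoding round-trip up to lower-casing and accent stripping. A hedged illustration using the toy WordPiece vocab from the BERT test earlier in this diff; the ids and output strings depend entirely on that vocab:

# Illustrative only: 'tokenizer' is a BertTokenizer built on the toy test vocab shown above.
ids = tokenizer.encode(u"UNwant\u00E9d,running")                   # [7, 4, 5, 10, 8, 9]
tokenizer.decode(ids)                                               # "unwanted, running"
tokenizer.decode(ids, clean_up_tokenization_spaces=False)           # "unwanted , running"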
@@ -23,6 +23,8 @@ import os
import regex as re
from io import open

from .model_utils import clean_up_tokenization

try:
    from functools import lru_cache
except ImportError:

@@ -275,9 +277,7 @@ class GPT2Tokenizer(object):
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        if clean_up_tokenization_spaces:
            text = text.replace('<unk>', '')
            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            text = clean_up_tokenization(text)
        return text

    def save_vocabulary(self, vocab_path):
@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm

from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

@@ -277,9 +278,7 @@ class OpenAIGPTTokenizer(object):
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
@@ -31,6 +31,7 @@ import torch
import numpy as np

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

if sys.version_info[0] == 2:
    import cPickle as pickle

@@ -109,6 +110,9 @@ class TransfoXLTokenizer(object):
        self.vocab_file = vocab_file
        self.never_split = never_split

        if vocab_file is not None:
            self.build_vocab()

    def count_file(self, path, verbose=False, add_eos=False):
        if verbose: print('counting file {} ...'.format(path))
        assert os.path.exists(path)

@@ -155,7 +159,7 @@ class TransfoXLTokenizer(object):
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        torch.save(self.__dict__, vocab_file)
        return vocab_file
        return (vocab_file,)

    def build_vocab(self):
        if self.vocab_file:

@@ -251,12 +255,20 @@ class TransfoXLTokenizer(object):
    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    def decode(self, indices, exclude=None):
    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
        """Converts a sequence of indices into a string."""
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
            out_string = ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
            out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

        if clean_up_tokenization_spaces:
            out_string = clean_up_tokenization(out_string)

        return out_string

    def __len__(self):
        return len(self.idx2sym)
@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm

from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

@@ -285,9 +286,7 @@ class XLMTokenizer(object):
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        if clean_up_tokenization_spaces:
            out_string = out_string.replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):
@@ -27,6 +27,7 @@ import unicodedata
import six

from .file_utils import cached_path
from .model_utils import clean_up_tokenization

logger = logging.getLogger(__name__)

@@ -316,9 +317,7 @@ class XLNetTokenizer(object):
        out_string = ''.join(tokens)
        if clean_up_tokenization_spaces:
            out_string = out_string.strip().replace('<unk>', '')
            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
            out_string = clean_up_tokenization(out_string)
        return out_string

    def save_vocabulary(self, vocab_path):