added code to raise value error for bert tokenizer for covert_tokens_to_indices

This commit is contained in:
Patrick Lewis 2018-12-18 14:41:30 +00:00
parent 786cc41299
commit 78cf7b4ab4
2 changed files with 53 additions and 13 deletions

View File

@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
'bert-large-cased': 512,
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
@ -65,7 +74,8 @@ def whitespace_tokenize(text):
class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True):
def __init__(self, vocab_file, do_lower_case=True, max_len=None):
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@ -75,6 +85,7 @@ class BertTokenizer(object):
[(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
split_tokens = []
@ -88,6 +99,12 @@ class BertTokenizer(object):
ids = []
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids):
@ -126,6 +143,11 @@ class BertTokenizer(object):
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
@ -193,7 +215,7 @@ class BasicTokenizer(object):
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
@ -218,17 +240,17 @@ class BasicTokenizer(object):
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []

View File

@ -44,12 +44,30 @@ class TokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_full_tokenizer_raises_error_for_long_sequences(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
]
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
tokenizer = BertTokenizer(vocab_file, max_len=10)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
indices = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(indices, [0 for _ in range(10)])
tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
def test_chinese(self):
tokenizer = BasicTokenizer()
self.assertListEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = BasicTokenizer(do_lower_case=True)