mirror of https://github.com/huggingface/transformers.git
added code to raise ValueError in the BERT tokenizer's convert_tokens_to_ids
parent 786cc41299
commit 78cf7b4ab4
@@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
 }
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'bert-base-uncased': 512,
+    'bert-large-uncased': 512,
+    'bert-base-cased': 512,
+    'bert-large-cased': 512,
+    'bert-base-multilingual-uncased': 512,
+    'bert-base-multilingual-cased': 512,
+    'bert-base-chinese': 512,
+}
 VOCAB_NAME = 'vocab.txt'


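Note: the new PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP records how many positional embeddings each pretrained checkpoint was trained with (512 for every current BERT model); a sequence longer than that cannot be encoded by the model at all. A minimal standalone sketch of the lookup pattern (the helper name is illustrative, not part of the commit; the fallback mirrors the int(1e12) "effectively unbounded" default used further down):

    PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
        'bert-base-uncased': 512,
        'bert-base-chinese': 512,
        # ... the remaining checkpoints, all 512 ...
    }

    def positional_embedding_limit(model_name):
        # Unknown checkpoints fall back to an effectively unlimited ceiling.
        return PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP.get(model_name, int(1e12))

    print(positional_embedding_limit('bert-base-uncased'))  # 512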
@@ -65,7 +74,8 @@ def whitespace_tokenize(text):

 class BertTokenizer(object):
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
-    def __init__(self, vocab_file, do_lower_case=True):
+
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None):
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -75,6 +85,7 @@ class BertTokenizer(object):
             [(ids, tok) for tok, ids in self.vocab.items()])
         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.max_len = max_len if max_len is not None else int(1e12)

     def tokenize(self, text):
         split_tokens = []
@@ -88,6 +99,12 @@ class BertTokenizer(object):
         ids = []
         for token in tokens:
             ids.append(self.vocab[token])
+        if len(ids) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this BERT model ({} > {}). Running this"
+                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+            )
         return ids

     def convert_ids_to_tokens(self, ids):
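Note: this is the check the commit message refers to; convert_tokens_to_ids now fails fast instead of letting the model hit an out-of-range positional-embedding index later. A self-contained sketch of triggering it (toy vocab, scratch path, and import path are assumptions for illustration):

    import os
    from tokenization import BertTokenizer  # import path is an assumption; adjust to your checkout

    vocab_path = "/tmp/toy_vocab.txt"        # hypothetical scratch file
    with open(vocab_path, "w") as f:
        f.write("[UNK]\nhello\nworld\n")     # ids 0, 1, 2

    tokenizer = BertTokenizer(vocab_path, max_len=2)
    os.remove(vocab_path)

    tokenizer.convert_tokens_to_ids(["hello", "world"])  # fine: 2 <= max_len
    try:
        tokenizer.convert_tokens_to_ids(["hello", "world", "hello"])
    except ValueError as e:
        print(e)  # "... (3 > 2). Running this sequence through BERT will result in indexing errors"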
@@ -126,6 +143,11 @@ class BertTokenizer(object):
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
+        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
         tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
         return tokenizer
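Note: the min() on the last added line means a caller-supplied max_len is honored only when it is stricter than the checkpoint's positional-embedding budget. A standalone sketch of that clamping semantics (function name is illustrative, not from the commit):

    def clamp_max_len(user_max_len, model_limit):
        # Mirrors kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len):
        # an unset user value acts as "unbounded", so the model limit wins,
        # while a stricter user value survives the min().
        return min(user_max_len if user_max_len is not None else int(1e12), model_limit)

    print(clamp_max_len(None, 512))  # 512 - no user setting, model limit applies
    print(clamp_max_len(128, 512))   # 128 - stricter user setting is kept
    print(clamp_max_len(2048, 512))  # 512 - user cannot exceed the model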
@@ -193,7 +215,7 @@ class BasicTokenizer(object):
             i += 1

         return ["".join(x) for x in output]
-
+
     def _tokenize_chinese_chars(self, text):
         """Adds whitespace around any CJK character."""
         output = []
@@ -218,17 +240,17 @@ class BasicTokenizer(object):
         # space-separated words, so they are not treated specially and handled
         # like the all of the other languages.
         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
-            (cp >= 0x3400 and cp <= 0x4DBF) or  #
-            (cp >= 0x20000 and cp <= 0x2A6DF) or  #
-            (cp >= 0x2A700 and cp <= 0x2B73F) or  #
-            (cp >= 0x2B740 and cp <= 0x2B81F) or  #
-            (cp >= 0x2B820 and cp <= 0x2CEAF) or
-            (cp >= 0xF900 and cp <= 0xFAFF) or  #
-            (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+                (cp >= 0x3400 and cp <= 0x4DBF) or  #
+                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
+                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
+                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
+                (cp >= 0x2B820 and cp <= 0x2CEAF) or
+                (cp >= 0xF900 and cp <= 0xFAFF) or  #
+                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
             return True

         return False

     def _clean_text(self, text):
         """Performs invalid character removal and whitespace cleanup on text."""
         output = []
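Note: the re-indented condition is the standard CJK Unified Ideographs range check inherited from Google's BERT; only whitespace changes in this hunk. Pulled out as a standalone helper for readability (a sketch, not part of the commit):

    def is_cjk_codepoint(cp):
        # CJK Unified Ideographs, extensions A-E, and the two
        # compatibility-ideograph blocks, matching the ranges above.
        return ((0x4E00 <= cp <= 0x9FFF) or
                (0x3400 <= cp <= 0x4DBF) or
                (0x20000 <= cp <= 0x2A6DF) or
                (0x2A700 <= cp <= 0x2B73F) or
                (0x2B740 <= cp <= 0x2B81F) or
                (0x2B820 <= cp <= 0x2CEAF) or
                (0xF900 <= cp <= 0xFAFF) or
                (0x2F800 <= cp <= 0x2FA1F))

    print(is_cjk_codepoint(ord(u"\u535A")))  # True - character used in test_chinese below
    print(is_cjk_codepoint(ord("a")))        # False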
@@ -44,12 +44,30 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

+    def test_full_tokenizer_raises_error_for_long_sequences(self):
+        vocab_tokens = [
+            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
+            "##ing", ","
+        ]
+        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+            vocab_file = vocab_writer.name
+
+        tokenizer = BertTokenizer(vocab_file, max_len=10)
+        os.remove(vocab_file)
+        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
+        indices = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(indices, [0 for _ in range(10)])
+
+        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
+        self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()

         self.assertListEqual(
             tokenizer.tokenize(u"ah\u535A\u63A8zz"),
-            [u"ah", u"\u535A", u"\u63A8", u"zz"])
+            [u"ah", u"\u535A", u"\u63A8", u"zz"])

     def test_basic_tokenizer_lower(self):
         tokenizer = BasicTokenizer(do_lower_case=True)