transformers/tests/test_tokenization_fast.py
Funtowicz Morgan 3f3fa7f7da
Integrate fast tokenizers library inside transformers (#2674)
* Implemented fast version of tokenizers

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bumped tokenizers version requirements to latest 0.2.1

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added matching tests

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Matching OpenAI GPT tokenization!

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Matching GPT2 on tokenizers

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Expose add_prefix_space as constructor parameter for GPT2

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Matching Roberta tokenization!

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Removed fast implementation of CTRL.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Binding TransformerXL tokenizers to Rust.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Updating tests accordingly.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added tokenizers as top-level modules.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Black & isort.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Rename LookupTable to WordLevel to match Rust side.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Black.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use "fast" suffix instead of "ru" for rust tokenizers implementations.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Introduce tokenize() method on fast tokenizers.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
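
For illustration, the tokenize() surface on a fast tokenizer mirrors the Python one; the model name below is just an example:

# Example: tokenize() on a Rust-backed tokenizer returns sub-word strings,
# without adding special tokens.
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
print(tokenizer.tokenize("Hello world"))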

* encode_plus dispatches to batch_encode_plus

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* batch_encode_plus now dispatches to encode if there is only one input element.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bind all the encode_plus parameters to the forwarded batch_encode_plus call.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
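
A minimal sketch of this dispatch, assuming a batch_encode_plus that accepts a single-element list of (text, pair) inputs; this is an illustration of the idea, not the actual transformers implementation:

# Hedged sketch of the encode_plus -> batch_encode_plus dispatch described above.
# Names and the unwrapping rule are assumptions made for illustration only.
def encode_plus(self, text, text_pair=None, return_tensors=None, **kwargs):
    batched_input = [(text, text_pair)] if text_pair is not None else [text]
    batched_output = self.batch_encode_plus(batched_input, return_tensors=return_tensors, **kwargs)

    # Without a tensor type, drop the leading batch axis of size 1.
    if return_tensors is None:
        return {key: value[0] for key, value in batched_output.items()}
    return batched_output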

* Bump tokenizers dependency to 0.3.0

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Formatting.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix tokenization_auto with support for new (python, fast) mapping schema.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Give correct fixtures path in test_tokenization_fast.py for the CLI.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Expose max_len_ properties on BertTokenizerFast

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Move max_len_ properties to PreTrainedTokenizerFast and override in specific subclasses.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* _convert_encoding should keep the batch axis tensor if there is only one sample in the batch.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add warning message for RobertaTokenizerFast if used for MLM.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added use_fast (bool) parameter on AutoTokenizer.from_pretrained().

This makes it easy to enable/disable Rust-based tokenizer instantiation.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
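
For example, selecting the implementation at load time (model name is illustrative):

from transformers import AutoTokenizer

fast_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
slow_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)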

* Let's tokenizers handle all the truncation and padding stuff.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Allow providing tokenizer arguments during pipeline creation.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Update test_fill_mask pipeline to not use fast tokenizers.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix too much parameters for convert_encoding.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* When enabling padding, max_length should be set to None.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Avoid returning nested tensors of length 1 when calling encode_plus

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure output is padded when return_tensors is not None.

Tensor creation requires the initial list inputs to be of exactly the same size.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Disable transfoxl unittest if pytorch is not available (required to load the model)

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* encode_plus should not remove the leading batch axis if return_tensors is set

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Temporary disable fast tokenizers on QA pipelines.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix formatting issues.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Update tokenizers to 0.4.0

* Update style

* Enable truncation + stride unit test on fast tokenizers.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add unittest ensuring special_tokens set match between Python and Rust.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure special_tokens are correctly set during construction.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Give more warning feedback to the user in case of padding without pad_token.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* quality & format.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added the possibility to add a single token as a str

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added unittest for add_tokens and add_special_tokens on fast tokenizers.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix rebase mismatch on pipelines qa default model.

QA requires cased input while the tokenizers would be uncased.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Using offset mapping relative to the original string + unittest.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
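
A short illustration of what the mapping gives back, assuming a Rust-backed tokenizer such as BertTokenizerFast; each offset is a (start, end) pair of character positions in the original string:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text = "HuggingFace is based in NYC"
encoding = tokenizer.encode_plus(text, return_offsets_mapping=True)

for token_id, offset in zip(encoding["input_ids"], encoding["offset_mapping"]):
    if offset:  # special tokens carry an empty offset entry
        start, end = offset
        print(token_id, repr(text[start:end]))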

* Addressing review comment: save_vocabulary requires folder and file name

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Simplify import for Bert.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: truncate_and_pad disables padding according to the same heuristic as the one enabling padding.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Remove private member access in tokenize()

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Bump tokenizers dependency to 0.4.2

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* format & quality.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Use named arguments when applicable.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Add Github link to Roberta/GPT2 space issue on masked input.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Move max_len_single_sentence / max_len_sentences_pair to PreTrainedTokenizerFast + tests.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
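
Roughly, the relationship these properties are expected to satisfy (an assumed invariant, shown only for illustration; model name is an example):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# The usable budget is the model max length minus the special tokens added
# for single and pair inputs respectively (assumed relationship).
assert tokenizer.max_len_single_sentence == tokenizer.max_len - tokenizer.num_added_tokens(False)
assert tokenizer.max_len_sentences_pair == tokenizer.max_len - tokenizer.num_added_tokens(True)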

* Addressing review comment: Relax type checking to include tuple and list objects.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing review comment: Document the truncate_and_pad manager behavior.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Raise an exception if return_offsets_mapping is not available with the current tokenizer.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure padding is set on the tokenizers before setting any padding strategy + unittest.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* On pytorch we need to stack tensor to get proper new axis.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Generalize tests to different frameworks by removing hard-coded return_tensors="..."

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bump tokenizer dependency for num_special_tokens_to_add

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Overflowing tokens in batch_encode_plus are now stacked over the batch axis.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Improved error message for padding strategy without pad token.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bumping tokenizers dependency to 0.5.0 for release.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Optimizing convert_encoding around 4x improvement. 🚀

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* expose pad_to_max_length in encode_plus to avoid duplicating the parameters in kwargs

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
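
A quick example of the intended behavior (model name illustrative):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode_plus("A short sentence", max_length=16, pad_to_max_length=True)

# With pad_to_max_length=True the sequence is padded up to max_length.
assert len(encoding["input_ids"]) == 16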

* Generate a proper overflow_to_sampling_mapping when return_overflowing_tokens is True.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
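
Illustrative usage; the key name follows the one exercised in the test below (overflow_to_sample_mapping), and exact shapes depend on the tokenizer and version:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
batch = tokenizer.batch_encode_plus(
    ["a fairly long first sample that will overflow into several windows", "short second sample"],
    max_length=6,
    pad_to_max_length=True,
    return_overflowing_tokens=True,
)

# Each overflowing window records the index of the input sample it came from.
print(batch["overflow_to_sample_mapping"])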

* Fix unittests for overflow_to_sampling_mapping not being returned as tensor.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Format & quality.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove perfect alignment constraint for Roberta (allowing 1% difference max)

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Triggering final CI

Co-authored-by: MOI Anthony <xn1t0x@gmail.com>
2020-02-19 11:35:40 -05:00


import unittest

import numpy as np

from tests.utils import require_torch
from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    DistilBertTokenizer,
    GPT2Tokenizer,
    GPT2TokenizerFast,
    OpenAIGPTTokenizer,
    PreTrainedTokenizer,
    RobertaTokenizer,
    TransfoXLTokenizer,
    is_torch_available,
)
from transformers.tokenization_distilbert import DistilBertTokenizerFast
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast


class FastTokenizerMatchingTest(unittest.TestCase):
    def setUp(self) -> None:
        with open("tests/fixtures/sample_text.txt") as f_data:
            self._data = f_data.read().replace("\n\n", "\n").strip()

    def assert_sequence_almost_equals(self, a, b, threshold):
        # Handle padding
        if len(a) != len(b):
            max_len = max(len(a), len(b))

            # Pad with a negative number as the vocab doesn't allow idx < 0;
            # the padding will then be counted as differences.
            if len(a) < max_len:
                a += [-1] * (max_len - len(a))
            if len(b) < max_len:
                b += [-1] * (max_len - len(b))

        # Convert to numpy for convenience
        a_, b_ = np.array(a), np.array(b)

        # Compute elementwise difference
        inputs_diffs = a_ - b_
        inputs_diff = np.count_nonzero(inputs_diffs)
        self.assertLessEqual(inputs_diff / a_.shape[0], threshold)

    def assert_tokenization_python_rust_almost_equals(self, tokenizer_p, tokenizer_r, threshold: float):
        # Ensure basic input match
        input_p = tokenizer_p.encode_plus(self._data)
        input_r = tokenizer_r.encode_plus(self._data)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

        input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
        input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_pairs_p.keys()):
            self.assert_sequence_almost_equals(input_pairs_p[key], input_pairs_r[key], threshold)

        # Ensure truncation match
        input_p = tokenizer_p.encode_plus(self._data, max_length=512)
        input_r = tokenizer_r.encode_plus(self._data, max_length=512)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

        # Ensure truncation with stride match
        input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
        input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

    def assert_add_tokens(self, tokenizer_r):
        vocab_size = tokenizer_r.vocab_size

        self.assertEqual(tokenizer_r.add_tokens(""), 0)
        self.assertEqual(tokenizer_r.add_tokens("testoken"), 1)
        self.assertEqual(tokenizer_r.add_tokens(["testoken1", "testtoken2"]), 2)
        self.assertEqual(len(tokenizer_r), vocab_size + 3)

        self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
        self.assertRaises(
            AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
        )
        self.assertEqual(tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
        self.assertEqual(
            tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
        )
        self.assertEqual(len(tokenizer_r), vocab_size + 6)

    def assert_offsets_mapping(self, tokenizer):
        text = "Wonderful no inspiration example with subtoken"
        pair = "Along with an awesome pair"

        # No pair
        tokens_with_offsets = tokenizer.encode_plus(
            text, return_special_tokens_mask=True, return_offsets_mapping=True
        )
        added_tokens = tokenizer.num_added_tokens(False)
        offsets = tokens_with_offsets["offset_mapping"]

        # Assert there is the same number of tokens and offsets
        self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))

        # Assert there are exactly `added_tokens` special tokens (the entries with empty offsets)
        self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
        self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)

        # Pairs
        tokens_with_offsets = tokenizer.encode_plus(
            text, pair, return_special_tokens_mask=True, return_offsets_mapping=True
        )
        added_tokens = tokenizer.num_added_tokens(True)
        offsets = tokens_with_offsets["offset_mapping"]

        # Assert there is the same number of tokens and offsets
        self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))

        # Assert there are exactly `added_tokens` special tokens (the entries with empty offsets)
        self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
        self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)

    def assert_batch_encode_dynamic_overflowing(self, tokenizer: PreTrainedTokenizer):
        """
        When calling batch_encode_plus with multiple sequences, it can return a different number of
        overflowing encodings for each sequence:

        [
            Sequence 1: [Encoding 1, Encoding 2],
            Sequence 2: [Encoding 1],
            Sequence 3: [Encoding 1, Encoding 2, ..., Encoding N]
        ]

        This needs to be padded so that it can be represented as a tensor.
        """
        returned_tensor = "pt" if is_torch_available() else "tf"

        tokens = tokenizer.encode_plus(
            "HuggingFace is solving NLP one commit at a time",
            max_length=6,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)

        # Mono sample
        tokens = tokenizer.batch_encode_plus(
            ["HuggingFace is solving NLP one commit at a time"],
            max_length=6,
            pad_to_max_length=True,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)
            self.assertEqual(tokens[key].shape[-1], 6)

        # Multi sample
        tokens = tokenizer.batch_encode_plus(
            ["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
            max_length=6,
            pad_to_max_length=True,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)
            self.assertEqual(tokens[key].shape[-1], 6)

    def test_bert(self):
        for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = BertTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "Bert tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    @require_torch
    def test_transfoxl(self):
        for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
            tokenizer_p = TransfoXLTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = TransfoXLTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "TransfoXL tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

    def test_distilbert(self):
        for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = DistilBertTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # DistilBert should match 100%
            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "DistilBert tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    def test_gpt2(self):
        for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = GPT2TokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "GPT2 tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

    def test_roberta(self):
        for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = RobertaTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "Roberta tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.01)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    def test_openai(self):
        for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = OpenAIGPTTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the set of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "GPT tokenizers don't have the same set of special_tokens",
            )

            # Assure tokenization overlap between python and rust impl.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)


if __name__ == "__main__":
    unittest.main()