mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00

* Implemented fast version of tokenizers
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Bumped tokenizers version requirements to latest 0.2.1
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Added matching tests
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Matching OpenAI GPT tokenization !
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Matching GPT2 on tokenizers
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Expose add_prefix_space as constructor parameter for GPT2
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Matching Roberta tokenization !
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Removed fast implementation of CTRL.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Binding TransformerXL tokenizers to Rust.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Updating tests accordingly.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Added tokenizers as top-level modules.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Black & isort.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Rename LookupTable to WordLevel to match Rust side.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Black.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Use "fast" suffix instead of "ru" for rust tokenizers implementations.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Introduce tokenize() method on fast tokenizers.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* encode_plus dispatches to batch_encode_plus
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* batch_encode_plus now dispatches to encode if there is only one input element.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Bind all the encode_plus parameters to the forwarded batch_encode_plus call.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Bump tokenizers dependency to 0.3.0
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Formatting.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix tokenization_auto with support for new (python, fast) mapping schema.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Give correct fixtures path in test_tokenization_fast.py for the CLI.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Expose max_len_ properties on BertTokenizerFast
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Move max_len_ properties to PreTrainedTokenizerFast and override in specific subclasses.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* _convert_encoding should keep the batch axis tensor if only one sample in the batch.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Add warning message for RobertaTokenizerFast if used for MLM.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Added use_fast (bool) parameter on AutoTokenizer.from_pretrained().
This makes it easy to enable/disable Rust-based tokenizer instantiation.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Let tokenizers handle all the truncation and padding.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Allow providing tokenizer arguments during pipeline creation.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Update test_fill_mask pipeline to not use fast tokenizers.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix too many parameters for convert_encoding.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* When enabling padding, max_length should be set to None.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Avoid returning nested tensors of length 1 when calling encode_plus
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Ensure output is padded when return_tensor is not None.
Tensor creation requires the initial list inputs to all be of the exact same size.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Disable transfoxl unittest if pytorch is not available (required to load the model)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* encode_plus should not remove the leading batch axis if return_tensor is set
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Temporarily disable fast tokenizers on QA pipelines.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix formatting issues.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Update tokenizers to 0.4.0
* Update style
* Enable truncation + stride unit test on fast tokenizers.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Add unittest ensuring special_tokens set match between Python and Rust.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Ensure special_tokens are correctly set during construction.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Give more warning feedback to the user in case of padding without pad_token.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* quality & format.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Added possibility to add a single token as str
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Added unittest for add_tokens and add_special_tokens on fast tokenizers.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix rebase mismatch on pipelines qa default model.
QA requires cased input while the tokenizers would be uncased.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Using offset mapping relative to the original string + unittest.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: save_vocabulary requires folder and file name
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Simplify import for Bert.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: truncate_and_pad disables padding according to the same heuristic as the one enabling padding.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Remove private member access in tokenize()
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Bump tokenizers dependency to 0.4.2
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* format & quality.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Use named arguments when applicable.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Add Github link to Roberta/GPT2 space issue on masked input.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Move max_len_single_sentence / max_len_sentences_pair to PreTrainedTokenizerFast + tests.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Relax type checking to include tuple and list object.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Addressing review comment: Document the truncate_and_pad manager behavior.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Raise an exception if return_offsets_mapping is not available with the current tokenizer.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Ensure padding is set on the tokenizers before setting any padding strategy + unittest.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* On pytorch we need to stack tensor to get proper new axis.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Generalize tests to different frameworks by removing hard-coded return_tensors="..."
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Bump tokenizer dependency for num_special_tokens_to_add
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Overflowing tokens in batch_encode_plus are now stacked over the batch axis.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Improved error message for padding strategy without pad token.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Bumping tokenizers dependency to 0.5.0 for release.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Optimizing convert_encoding around 4x improvement. 🚀
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* expose pad_to_max_length in encode_plus to avoid duplicating the parameters in kwargs
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Generate a proper overflow_to_sampling_mapping when return_overflowing_tokens is True.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Fix unittests for overflow_to_sampling_mapping not being returned as tensor.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Format & quality.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Remove perfect alignment constraint for Roberta (allowing 1% difference max)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Triggering final CI
Co-authored-by: MOI Anthony <xn1t0x@gmail.com>
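
The "allowing 1% difference max" note for Roberta refers to a mismatch ratio between the ids produced by the Python and Rust tokenizers. The following is a minimal standard-library sketch of that idea, not the test suite's actual helper (which uses numpy); the function name `mismatch_ratio` and the sample ids are hypothetical.

```python
from typing import List


def mismatch_ratio(a: List[int], b: List[int]) -> float:
    """Return the fraction of positions where two id sequences differ."""
    max_len = max(len(a), len(b))
    if max_len == 0:
        return 0.0

    # Pad copies of the inputs with -1. Vocabulary ids are never negative,
    # so every padded position counts as a difference.
    a_padded = a + [-1] * (max_len - len(a))
    b_padded = b + [-1] * (max_len - len(b))

    differing = sum(1 for x, y in zip(a_padded, b_padded) if x != y)
    return differing / max_len


# Identical outputs give a ratio of 0.0, which satisfies a threshold of 0.0.
assert mismatch_ratio([101, 7592, 2088, 102], [101, 7592, 2088, 102]) == 0.0

# One differing id out of four positions is a 25% mismatch,
# far above a 1% (0.01) budget.
assert mismatch_ratio([101, 7592, 2088, 102], [101, 7592, 2089, 102]) == 0.25
```

A threshold of 0.0 therefore demands an exact match, while 0.01 tolerates one differing position per hundred.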
372 lines
18 KiB
Python
import unittest

import numpy as np

from tests.utils import require_torch
from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    DistilBertTokenizer,
    GPT2Tokenizer,
    GPT2TokenizerFast,
    OpenAIGPTTokenizer,
    PreTrainedTokenizer,
    RobertaTokenizer,
    TransfoXLTokenizer,
    is_torch_available,
)
from transformers.tokenization_distilbert import DistilBertTokenizerFast
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
from transformers.tokenization_roberta import RobertaTokenizerFast
from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast


class FastTokenizerMatchingTest(unittest.TestCase):
    def setUp(self) -> None:
        with open("tests/fixtures/sample_text.txt") as f_data:
            self._data = f_data.read().replace("\n\n", "\n").strip()

    def assert_sequence_almost_equals(self, a, b, threshold):

        # Handle padding
        if len(a) != len(b):
            max_len = max(len(a), len(b))

            # Pad with a negative number: the vocab doesn't allow idx < 0,
            # so padded positions will be tracked as differences
            if len(a) < max_len:
                a += [-1] * (max_len - len(a))

            if len(b) < max_len:
                b += [-1] * (max_len - len(b))

        # Convert to numpy for convenience
        a_, b_ = np.array(a), np.array(b)

        # Compute elementwise difference
        inputs_diffs = a_ - b_
        inputs_diff = np.count_nonzero(inputs_diffs)
        self.assertLessEqual(inputs_diff / a_.shape[0], threshold)

    def assert_tokenization_python_rust_almost_equals(self, tokenizer_p, tokenizer_r, threshold: float):
        # Ensure basic input match
        input_p = tokenizer_p.encode_plus(self._data)
        input_r = tokenizer_r.encode_plus(self._data)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

        input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
        input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_pairs_p[key], input_pairs_r[key], threshold)

        # Ensure truncation match
        input_p = tokenizer_p.encode_plus(self._data, max_length=512)
        input_r = tokenizer_r.encode_plus(self._data, max_length=512)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

        # Ensure truncation with stride match
        input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
        input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)

        for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
            self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)

    def assert_add_tokens(self, tokenizer_r):
        vocab_size = tokenizer_r.vocab_size
        self.assertEqual(tokenizer_r.add_tokens(""), 0)
        self.assertEqual(tokenizer_r.add_tokens("testoken"), 1)
        self.assertEqual(tokenizer_r.add_tokens(["testoken1", "testtoken2"]), 2)
        self.assertEqual(len(tokenizer_r), vocab_size + 3)

        self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
        self.assertRaises(
            AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
        )
        self.assertEqual(tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
        self.assertEqual(
            tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
        )
        self.assertEqual(len(tokenizer_r), vocab_size + 6)

    def assert_offsets_mapping(self, tokenizer):
        text = "Wonderful no inspiration example with subtoken"
        pair = "Along with an awesome pair"

        # No pair
        tokens_with_offsets = tokenizer.encode_plus(text, return_special_tokens_mask=True, return_offsets_mapping=True)
        added_tokens = tokenizer.num_added_tokens(False)
        offsets = tokens_with_offsets["offset_mapping"]

        # Assert there is the same number of tokens and offsets
        self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))

        # Assert there are only added_tokens special tokens
        self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
        self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)

        # Pairs
        tokens_with_offsets = tokenizer.encode_plus(
            text, pair, return_special_tokens_mask=True, return_offsets_mapping=True
        )
        added_tokens = tokenizer.num_added_tokens(True)
        offsets = tokens_with_offsets["offset_mapping"]

        # Assert there is the same number of tokens and offsets
        self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))

        # Assert there are only added_tokens special tokens
        self.assertEqual(sum([0 if x else 1 for x in offsets]), added_tokens)
        self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)

    def assert_batch_encode_dynamic_overflowing(self, tokenizer: PreTrainedTokenizer):
        """
        When calling batch_encode with multiple sequences, it can return a different number of
        overflowing encodings for each sequence:
        [
          Sequence 1: [Encoding 1, Encoding 2],
          Sequence 2: [Encoding 1],
          Sequence 3: [Encoding 1, Encoding 2, ... Encoding N]
        ]
        This needs to be padded so that it can be represented as a tensor.
        """
        returned_tensor = "pt" if is_torch_available() else "tf"

        tokens = tokenizer.encode_plus(
            "HuggingFace is solving NLP one commit at a time",
            max_length=6,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)

        # Mono sample
        tokens = tokenizer.batch_encode_plus(
            ["HuggingFace is solving NLP one commit at a time"],
            max_length=6,
            pad_to_max_len=True,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)
            self.assertEqual(tokens[key].shape[-1], 6)

        # Multi sample
        tokens = tokenizer.batch_encode_plus(
            ["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
            max_length=6,
            pad_to_max_len=True,
            return_tensors=returned_tensor,
            return_overflowing_tokens=True,
        )

        for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
            self.assertEqual(len(tokens[key].shape), 2)
            self.assertEqual(tokens[key].shape[-1], 6)

    def test_bert(self):
        for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = BertTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "Bert tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    @require_torch
    def test_transfoxl(self):
        for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
            tokenizer_p = TransfoXLTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = TransfoXLTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "TransfoXL tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

    def test_distilbert(self):
        for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = DistilBertTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # DistilBert should match 100%
            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "DistilBert tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    def test_gpt2(self):
        for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = GPT2TokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "GPT2 tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

    def test_roberta(self):
        for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = RobertaTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "Roberta tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization (up to 1% difference).
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.01)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

    def test_openai(self):
        for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
            tokenizer_r = OpenAIGPTTokenizerFast.from_pretrained(tokenizer_name)

            # Check we have the same number of added_tokens for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.num_added_tokens(False), tokenizer_p.num_added_tokens(False))
            self.assertEqual(tokenizer_r.num_added_tokens(True), tokenizer_p.num_added_tokens(True))

            # Check we have the correct max_length for both pair and non-pair inputs.
            self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
            self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)

            # Assert the sets of special tokens match.
            self.assertSequenceEqual(
                tokenizer_p.special_tokens_map.items(),
                tokenizer_r.special_tokens_map.items(),
                "GPT tokenizers don't have the same set of special_tokens",
            )

            # Ensure the Python and Rust implementations produce matching tokenization.
            self.assert_tokenization_python_rust_almost_equals(tokenizer_p, tokenizer_r, 0.0)

            # Ensure add_tokens and add_special_tokens return the correct vocab size
            self.assert_add_tokens(tokenizer_r)

            # Check for offsets mapping
            self.assert_offsets_mapping(tokenizer_r)

            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)


if __name__ == "__main__":
    unittest.main()
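
The docstring of assert_batch_encode_dynamic_overflowing describes a ragged result: each input sequence may yield a different number of overflowing encodings, and these must be padded before they can form a tensor. Below is a minimal standard-library sketch of that stacking step, assuming plain lists of ids. The helper name `stack_overflowing` and its `pad_id` parameter are hypothetical and are not part of the transformers or tokenizers API.

```python
from typing import Dict, List


def stack_overflowing(batch: List[List[List[int]]], pad_id: int = 0) -> Dict[str, list]:
    """Flatten per-sequence encodings into one rectangular batch.

    `batch[i]` holds every encoding (first chunk plus overflow chunks)
    produced for input sequence i.
    """
    rows: List[List[int]] = []
    mapping: List[int] = []
    for sample_index, encodings in enumerate(batch):
        for encoding in encodings:
            rows.append(list(encoding))
            # Remember which input sequence this row came from.
            mapping.append(sample_index)

    # Pad every row to the longest one so the result is rectangular.
    width = max((len(row) for row in rows), default=0)
    padded = [row + [pad_id] * (width - len(row)) for row in rows]
    return {"input_ids": padded, "overflow_to_sample_mapping": mapping}


ragged = [
    [[5, 6, 7], [8, 9]],  # sequence 1: two encodings
    [[10]],               # sequence 2: one encoding
]
result = stack_overflowing(ragged)
assert result["input_ids"] == [[5, 6, 7], [8, 9, 0], [10, 0, 0]]
assert result["overflow_to_sample_mapping"] == [0, 0, 1]
```

All rows end up along a single batch axis, and the mapping records which original sequence each row belongs to, which is why the tests above skip that key when checking for two-dimensional shapes.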