transformers/tests/test_tokenization_clip.py
SaulLu 476ba679dd
Feature to use the PreTrainedTokenizerFast class as a stand-alone tokenizer (#11810)
* feature for tokenizer without slow/legacy version

* format

* modify common test

* add tests

* add PreTrainedTokenizerFast to AutoTokenizer

* format

* change tokenizer common test in order to be able to run test without a slow version

* update tokenizer fast test in order to use `rust_tokenizer_class` attribute instead of `tokenizer_class`

* add AutoTokenizer test

* replace `if self.tokenizer_class is not None` with `if self.tokenizer_class is None`

* remove obsolete change in comment

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/tokenization_utils_fast.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change `get_main_tokenizer` into `get_tokenizers`

* clarify `get_tokenizers` method

* homogenize with `test_slow_tokenizer` and `test_rust_tokenizer`

* add `test_rust_tokenizer = False` to tokenizers that don't define a fast version

* `test_rust_tokenizer = False` for BertJapaneseTokenizer

* `test_rust_tokenizer = False` for BertJapaneseCharacterTokenizationTest

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2021-06-14 11:58:44 +02:00

209 lines
8.9 KiB
Python

# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import unittest
from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    """Tokenization tests for ``CLIPTokenizer`` / ``CLIPTokenizerFast``.

    ``setUp`` writes a miniature BPE vocabulary and merges file under
    ``self.tmpdirname``; the helpers and tests below load tokenizers from there.
    """

    tokenizer_class = CLIPTokenizer
    rust_tokenizer_class = CLIPTokenizerFast
    # Checked by test_rust_and_python_full_tokenizers, which returns early when this is False.
    test_rust_tokenizer = False
    from_pretrained_kwargs = {"add_prefix_space": True}
    test_seq2seq = False

    def setUp(self):
        """Create the vocab and merges fixture files used by every test."""
        super().setUp()

        # fmt: off
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|endoftext|>"]
        # fmt: on
        # Each token's id is its position in the list above.
        token_to_id = {token: index for index, token in enumerate(vocab)}
        merge_rules = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        # NOTE(review): self.tmpdirname is not set in this class; presumably the parent setUp provides it.
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])

        with open(self.vocab_file, "w", encoding="utf-8") as vocab_handle:
            vocab_handle.write(json.dumps(token_to_id) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as merges_handle:
            merges_handle.write("\n".join(merge_rules))

    def get_tokenizer(self, **kwargs):
        """Load the slow tokenizer from the fixture directory.

        Entries of ``self.special_tokens_map`` take precedence over ``kwargs``.
        """
        options = {**kwargs, **self.special_tokens_map}
        return CLIPTokenizer.from_pretrained(self.tmpdirname, **options)

    def get_rust_tokenizer(self, **kwargs):
        """Load the fast tokenizer from the fixture directory.

        Entries of ``self.special_tokens_map`` take precedence over ``kwargs``.
        """
        options = {**kwargs, **self.special_tokens_map}
        return CLIPTokenizerFast.from_pretrained(self.tmpdirname, **options)

    def get_input_output_texts(self, tokenizer):
        """Return the ``(input_text, output_text)`` pair for this tokenizer's tests."""
        return "lower newer", "lower newer "

    def test_full_tokenizer(self):
        # Tokenize a short string with the slow tokenizer, then map the tokens
        # (plus the unknown token) to ids and compare against the fixture vocab.
        slow_tokenizer = CLIPTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)

        expected_tokens = ["lo", "w", "er</w>", "n", "e", "w", "er</w>"]
        actual_tokens = slow_tokenizer.tokenize("lower newer", add_prefix_space=True)
        self.assertListEqual(actual_tokens, expected_tokens)

        expected_ids = [10, 2, 12, 9, 3, 2, 12, 16]
        tokens_with_unk = actual_tokens + [slow_tokenizer.unk_token]
        self.assertListEqual(slow_tokenizer.convert_tokens_to_ids(tokens_with_unk), expected_ids)

    def test_rust_and_python_full_tokenizers(self):
        # Nothing to compare when the fast tokenizer is switched off for this class.
        if not self.test_rust_tokenizer:
            return

        slow = self.get_tokenizer()
        fast = self.get_rust_tokenizer(add_prefix_space=True)
        text = "lower newer"

        # Both implementations must produce the same tokens.
        slow_tokens = slow.tokenize(text, add_prefix_space=True)
        fast_tokens = fast.tokenize(text)
        self.assertListEqual(slow_tokens, fast_tokens)

        # Same ids when special tokens are left out.
        slow_ids = slow.encode(text, add_special_tokens=False, add_prefix_space=True)
        fast_ids = fast.encode(text, add_special_tokens=False)
        self.assertListEqual(slow_ids, fast_ids)

        # Same ids when special tokens are included (the fast tokenizer is reloaded first).
        fast = self.get_rust_tokenizer(add_prefix_space=True)
        slow_ids = slow.encode(text, add_prefix_space=True)
        fast_ids = fast.encode(text)
        self.assertListEqual(slow_ids, fast_ids)

        # The unknown token must map to its id in the fixture vocab.
        expected_ids = [10, 2, 12, 9, 3, 2, 12, 16]
        tokens_with_unk = slow_tokens + [fast.unk_token]
        self.assertListEqual(fast.convert_tokens_to_ids(tokens_with_unk), expected_ids)

    def test_pretokenized_inputs(self, *args, **kwargs):
        # Intentionally a no-op: pretokenization is very hard to combine with
        # byte-level tokenization while keeping both CLIP and Roberta working,
        # mostly because of the space added in front of the string.
        pass

    def test_padding(self, max_length=15):
        # Every padded encode call below is expected to raise ValueError.
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                fast = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                single = "This is a simple input"
                single_batch = ["This is a simple input 1", "This is a simple input 2"]
                pair = ("This is a simple input", "This is a pair")
                pair_batch = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                ]

                # (method, input) combinations: simple inputs first, then pair inputs.
                padded_calls = [
                    (fast.encode, single),
                    (fast.encode_plus, single),
                    (fast.batch_encode_plus, single_batch),
                    (fast.encode, pair),
                    (fast.encode_plus, pair),
                    (fast.batch_encode_plus, pair_batch),
                ]
                for encode_method, payload in padded_calls:
                    self.assertRaises(
                        ValueError,
                        encode_method,
                        payload,
                        max_length=max_length,
                        padding="max_length",
                    )

    def test_add_tokens_tokenizer(self):
        # Adding regular and special tokens must grow len(tokenizer) but leave
        # tokenizer.vocab_size untouched, and the new tokens must encode to ids
        # beyond the base vocabulary.
        for tokenizer in self.get_tokenizers(do_lower_case=False):
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                base_vocab_size = tokenizer.vocab_size
                base_total = len(tokenizer)

                self.assertNotEqual(base_vocab_size, 0)
                # base_vocab_size == base_total is deliberately not asserted: the test
                # fixtures are smaller than the original vocabs, so tokens have usually
                # been added from the start.

                regular_tokens = ["aaaaa bbbbbb", "cccccccccdddddddd"]
                regular_added = tokenizer.add_tokens(regular_tokens)
                vocab_size_after_regular = tokenizer.vocab_size
                total_after_regular = len(tokenizer)

                self.assertNotEqual(vocab_size_after_regular, 0)
                self.assertEqual(base_vocab_size, vocab_size_after_regular)
                self.assertEqual(regular_added, len(regular_tokens))
                self.assertEqual(total_after_regular, base_total + len(regular_tokens))

                encoded = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

                self.assertGreaterEqual(len(encoded), 4)
                self.assertGreater(encoded[0], tokenizer.vocab_size - 1)
                self.assertGreater(encoded[-2], tokenizer.vocab_size - 1)

                special_tokens = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
                special_added = tokenizer.add_special_tokens(special_tokens)
                vocab_size_after_special = tokenizer.vocab_size
                total_after_special = len(tokenizer)

                self.assertNotEqual(vocab_size_after_special, 0)
                self.assertEqual(base_vocab_size, vocab_size_after_special)
                self.assertEqual(special_added, len(special_tokens))
                self.assertEqual(total_after_special, total_after_regular + len(special_tokens))

                encoded = tokenizer.encode(
                    ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
                )

                self.assertGreaterEqual(len(encoded), 6)
                self.assertGreater(encoded[0], tokenizer.vocab_size - 1)
                self.assertGreater(encoded[0], encoded[1])
                self.assertGreater(encoded[-2], tokenizer.vocab_size - 1)
                self.assertGreater(encoded[-2], encoded[-3])
                self.assertEqual(encoded[0], tokenizer.eos_token_id)
                # No check of encoded[-2] against tokenizer.pad_token_id: padding is very
                # hacky in CLIPTokenizer and pad_token_id is always 0.