import unittest

from pathlib import Path

from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast

from .test_tokenization_common import TokenizerTesterMixin


class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = PegasusTokenizer
    rust_tokenizer_class = PegasusTokenizerFast
    test_rust_tokenizer = True

    def setUp(self):
        super().setUp()

        # Cache the sentencepiece vocab file in the temporary directory so that
        # from_pretrained(self.tmpdirname) works in the shared tokenizer tests.
        save_dir = Path(self.tmpdirname)
        spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
        if not (save_dir / spm_file).exists():
            tokenizer = self.pegasus_large_tokenizer
            tokenizer.save_pretrained(self.tmpdirname)

    @cached_property
    def pegasus_large_tokenizer(self):
        return PegasusTokenizer.from_pretrained("google/pegasus-large")

    @unittest.skip("add_tokens does not work yet")
    def test_swap_special_token(self):
        pass

    def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
        if not kwargs:
            return self.pegasus_large_tokenizer
        else:
            return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        return ("This is a test", "This is a test")

    def test_pegasus_large_tokenizer_settings(self):
        tokenizer = self.pegasus_large_tokenizer
        # The tracebacks for the following asserts are **better** without messages or self.assertEqual
        assert tokenizer.vocab_size == 96103
        assert tokenizer.pad_token_id == 0
        assert tokenizer.eos_token_id == 1
        assert tokenizer.offset == 103
        assert tokenizer.unk_token_id == tokenizer.offset + 2 == 105
        assert tokenizer.unk_token == "<unk>"
        assert tokenizer.mask_token is None
        assert tokenizer.mask_token_id is None
        assert tokenizer.model_max_length == 1024
        raw_input_str = "To ensure a smooth flow of bank resolutions."
        desired_result = [413, 615, 114, 2291, 1971, 113, 1679, 10710, 107, 1]
        ids = tokenizer([raw_input_str], return_tensors=None).input_ids[0]
        self.assertListEqual(desired_result, ids)
        assert tokenizer.convert_ids_to_tokens([0, 1, 2]) == ["<pad>", "</s>", "unk_2"]

    @require_torch
    def test_pegasus_large_seq2seq_truncation(self):
        src_texts = ["This is going to be way too long" * 10000, "short example"]
        tgt_texts = ["not super long but more than 5 tokens", "tiny"]
        batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(
            src_texts, tgt_texts=tgt_texts, max_target_length=5
        )
        assert batch.input_ids.shape == (2, 1024)
        assert batch.attention_mask.shape == (2, 1024)
        assert "labels" in batch  # because tgt_texts was specified
        assert batch.labels.shape == (2, 5)
        assert len(batch) == 3  # input_ids, attention_mask, labels. Other things made by BartModel