🚨🚨[Whisper Tok] Update integration test (#29368)

* [Whisper Tok] Update integration test

* make style
Author: Sanchit Gandhi, 2024-03-01 09:22:31 +00:00 (committed by GitHub)
Parent: e7b9837065
Commit: 0a0a279e99

tests/models/whisper/test_tokenization_whisper.py

@@ -16,7 +16,7 @@ import unittest
 
 from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
 from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
-from transformers.testing_utils import require_jinja, slow
+from transformers.testing_utils import slow
 
 from ...test_tokenization_common import TokenizerTesterMixin
@@ -67,26 +67,26 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
 
         tokens = tokenizer.tokenize("This is a test")
-        self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġ", "test"])
+        self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġtest"])
 
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens),
-            [5723, 307, 257, 220, 31636],
+            [5723, 307, 257, 1500],
         )
 
         tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
         self.assertListEqual(
             tokens,
-            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."],  # fmt: skip
-        )  # fmt: skip
+            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."],  # fmt: skip
+        )
 
         ids = tokenizer.convert_tokens_to_ids(tokens)
-        self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13])
+        self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 341, 307, 16720, 526, 13])
 
         back_tokens = tokenizer.convert_ids_to_tokens(ids)
         self.assertListEqual(
             back_tokens,
-            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."],  # fmt: skip
-        )  # fmt: skip
+            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."],  # fmt: skip
+        )
 
     def test_tokenizer_slow_store_full_signature(self):
         pass
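
As a sanity check outside the test harness, the new expectations in this hunk can be reproduced directly. A minimal sketch, assuming the fixture saved under `self.tmpdirname` shares its vocabulary with the `openai/whisper-tiny` checkpoint (used here as a stand-in):

```python
from transformers import WhisperTokenizer

# Stand-in for the test fixture; the exact ids below only hold if this
# checkpoint's vocabulary matches the one the test was written against.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

tokens = tokenizer.tokenize("This is a test")
print(tokens)  # per the updated test: ["This", "Ġis", "Ġa", "Ġtest"]

ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)  # per the updated test: [5723, 307, 257, 1500]

# ids -> tokens round-trips, matching the back_tokens assertion above
assert tokenizer.convert_ids_to_tokens(ids) == tokens
```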
@@ -499,25 +499,3 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
 
         output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
         self.assertEqual(output, [])
-
-    @require_jinja
-    def test_tokenization_for_chat(self):
-        multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
-        # This is in English, but it's just here to make sure the chat control tokens are being added properly
-        test_chats = [
-            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
-            [
-                {"role": "system", "content": "You are a helpful chatbot."},
-                {"role": "user", "content": "Hello!"},
-                {"role": "assistant", "content": "Nice to meet you."},
-            ],
-            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
-        ]
-        tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
-        expected_tokens = [
-            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
-            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
-            [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
-        ]
-        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
-            self.assertListEqual(tokenized_chat, expected_tokens)
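
The deleted `test_tokenization_for_chat` was the only test gated behind `require_jinja`, which is why the first hunk drops that import. The surviving context lines above exercise `decode(..., output_offsets=True)`; a minimal sketch of that call on a sequence containing no timestamp tokens (the input string is illustrative, not taken from the test):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Offsets are derived from timestamp tokens; a plain-text sequence has
# none, so the "offsets" entry comes back as an empty list.
ids = tokenizer("Hello world", add_special_tokens=False).input_ids
result = tokenizer.decode(ids, output_offsets=True)
print(result["text"])     # Hello world
print(result["offsets"])  # []
```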