Fix flax whisper tokenizer bug (#33151)

* Update tokenization_whisper.py

Fix issue with flax whisper model

* Update tokenization_whisper_fast.py

Fix issue with flax whisper model

* Update tokenization_whisper.py

just check len of token_ids

* Update tokenization_whisper_fast.py

just use len of token_ids

* Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list

* Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list

* Update test_tokenization_whisper.py to add test for _convert_to_list method

* Update test_tokenization_whisper.py to fix code style issues

* Fix code style

* Fix code check again

* Update test_tokenization_whisper.py to improve code style

* Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available

* Update tests/models/whisper/test_tokenization_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method

* Revert the changes that were automatically applied by the formatter and were unrelated to this PR

* Format for minimal changes

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Hannan Komari 2024-09-12 14:51:59 +03:30 committed by GitHub
parent 516ee6adc2
commit 8ed635258c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 1 deletions

View File

@ -880,6 +880,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()

View File

@ -613,6 +613,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()

View File

@ -18,7 +18,7 @@ import numpy as np
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from transformers.testing_utils import require_flax, require_tf, require_torch, slow
from ...test_tokenization_common import TokenizerTesterMixin
@ -574,3 +574,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])
def test_convert_to_list_np(self):
    """_convert_to_list passes plain lists through unchanged and converts numpy arrays to nested lists."""
    expected = [[1, 2, 3], [4, 5, 6]]

    # An already-converted list is returned as-is by both tokenizer variants.
    for tokenizer_class in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_class._convert_to_list(expected), expected)

    # A numpy array is converted to the equivalent nested Python list.
    as_array = np.array(expected)
    for tokenizer_class in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_class._convert_to_list(as_array), expected)
@require_tf
def test_convert_to_list_tf(self):
    """_convert_to_list converts a TensorFlow tensor into a nested Python list."""
    import tensorflow as tf

    expected = [[1, 2, 3], [4, 5, 6]]
    tensor = tf.constant(expected)
    for tokenizer_class in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_class._convert_to_list(tensor), expected)
@require_flax
def test_convert_to_list_jax(self):
    """_convert_to_list converts a JAX array into a nested Python list."""
    import jax.numpy as jnp

    expected = [[1, 2, 3], [4, 5, 6]]
    array = jnp.array(expected)
    for tokenizer_class in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_class._convert_to_list(array), expected)
@require_torch
def test_convert_to_list_pt(self):
    """_convert_to_list converts a PyTorch tensor into a nested Python list."""
    import torch

    expected = [[1, 2, 3], [4, 5, 6]]
    tensor = torch.tensor(expected)
    for tokenizer_class in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_class._convert_to_list(tensor), expected)