mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix flax whisper tokenizer bug (#33151)
* Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization_whisper.py to improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes that were automatically applied by the formatter and were unrelated to the PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
516ee6adc2
commit
8ed635258c
@ -880,6 +880,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
token_ids = token_ids.cpu().numpy()
|
token_ids = token_ids.cpu().numpy()
|
||||||
elif "tensorflow" in str(type(token_ids)):
|
elif "tensorflow" in str(type(token_ids)):
|
||||||
token_ids = token_ids.numpy()
|
token_ids = token_ids.numpy()
|
||||||
|
elif "jaxlib" in str(type(token_ids)):
|
||||||
|
token_ids = token_ids.tolist()
|
||||||
# now the token ids are either a numpy array, or a list of lists
|
# now the token ids are either a numpy array, or a list of lists
|
||||||
if isinstance(token_ids, np.ndarray):
|
if isinstance(token_ids, np.ndarray):
|
||||||
token_ids = token_ids.tolist()
|
token_ids = token_ids.tolist()
|
||||||
|
@ -613,6 +613,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
token_ids = token_ids.cpu().numpy()
|
token_ids = token_ids.cpu().numpy()
|
||||||
elif "tensorflow" in str(type(token_ids)):
|
elif "tensorflow" in str(type(token_ids)):
|
||||||
token_ids = token_ids.numpy()
|
token_ids = token_ids.numpy()
|
||||||
|
elif "jaxlib" in str(type(token_ids)):
|
||||||
|
token_ids = token_ids.tolist()
|
||||||
# now the token ids are either a numpy array, or a list of lists
|
# now the token ids are either a numpy array, or a list of lists
|
||||||
if isinstance(token_ids, np.ndarray):
|
if isinstance(token_ids, np.ndarray):
|
||||||
token_ids = token_ids.tolist()
|
token_ids = token_ids.tolist()
|
||||||
|
@ -18,7 +18,7 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import require_flax, require_tf, require_torch, slow
|
||||||
|
|
||||||
from ...test_tokenization_common import TokenizerTesterMixin
|
from ...test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
@ -574,3 +574,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
|
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
|
||||||
self.assertEqual(output, [])
|
self.assertEqual(output, [])
|
||||||
|
|
||||||
|
def test_convert_to_list_np(self):
    """Check `_convert_to_list` on a nested Python list and on a NumPy array."""
    expected = [[1, 2, 3], [4, 5, 6]]
    tokenizer_classes = (WhisperTokenizer, WhisperTokenizerFast)

    # An input that is already a list of lists must compare equal to itself.
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(expected), expected)

    # A NumPy array must come back as the equivalent nested list.
    as_numpy = np.array(expected)
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(as_numpy), expected)
|
||||||
|
|
||||||
|
@require_tf
def test_convert_to_list_tf(self):
    """Check that `_convert_to_list` turns a TensorFlow tensor into a nested list."""
    # Imported locally so the module still loads when TensorFlow is absent.
    import tensorflow as tf

    expected = [[1, 2, 3], [4, 5, 6]]
    as_tf = tf.constant(expected)

    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_tf), expected)
|
||||||
|
|
||||||
|
@require_flax
def test_convert_to_list_jax(self):
    """Check that `_convert_to_list` turns a JAX array into a nested list."""
    # Imported locally so the module still loads when JAX is absent.
    import jax.numpy as jnp

    expected = [[1, 2, 3], [4, 5, 6]]
    as_jax = jnp.array(expected)

    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_jax), expected)
|
||||||
|
|
||||||
|
@require_torch
def test_convert_to_list_pt(self):
    """Check that `_convert_to_list` turns a PyTorch tensor into a nested list."""
    # Imported locally so the module still loads when PyTorch is absent.
    import torch

    expected = [[1, 2, 3], [4, 5, 6]]
    as_torch = torch.tensor(expected)

    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_torch), expected)
|
||||||
|
Loading…
Reference in New Issue
Block a user