mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
[Whisper] Fix decoder ids methods (#20599)
* [Whisper] Fix decoder ids methods
* enum property
This commit is contained in:
parent
ef0f85cd57
commit
74fb524e20
@ -583,5 +583,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
||||||
self.set_prefix_tokens(task=task, language=language, predict_timestamps=no_timestamps)
|
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
|
||||||
return self.prefix_tokens
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)]
|
||||||
|
return forced_decoder_ids
|
||||||
|
@ -26,6 +26,11 @@ if is_speech_available():
|
|||||||
from transformers import WhisperFeatureExtractor, WhisperProcessor
|
from transformers import WhisperFeatureExtractor, WhisperProcessor
|
||||||
|
|
||||||
|
|
||||||
|
START_OF_TRANSCRIPT = 50257
|
||||||
|
TRANSCRIBE = 50358
|
||||||
|
NOTIMESTAMPS = 50362
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@ -128,3 +133,17 @@ class WhisperProcessorTest(unittest.TestCase):
|
|||||||
feature_extractor.model_input_names,
|
feature_extractor.model_input_names,
|
||||||
msg="`processor` and `feature_extractor` model input names do not match",
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_get_decoder_prompt_ids(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", no_timestamps=True)
|
||||||
|
|
||||||
|
self.assertIsInstance(forced_decoder_ids, list)
|
||||||
|
for ids in forced_decoder_ids:
|
||||||
|
self.assertIsInstance(ids, (list, tuple))
|
||||||
|
|
||||||
|
expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS]
|
||||||
|
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
||||||
|
Loading…
Reference in New Issue
Block a user