mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Improve error messaging for ASR pipeline. (#19570)
* Improve error messaging for ASR pipeline. - Raise the error early (in `_sanitize`) so users don't waste time trying to run queries with invalid params. - Fix: the error was previously raised only after accessing `config.inputs_to_logits_ratio`, so our intended check was masked by the failure of that non-existent property. - Added a manual check on s2t for the error message. No non-CTC model seems to be used by the default runner (they are all skipped). * Removing pdb. * Dropping the early error since it doesn't really work :(.
This commit is contained in:
parent
5ef2186692
commit
463226e2ee
@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
||||
|
||||
if chunk_length_s:
|
||||
if self.type not in {"ctc", "ctc_with_lm"}:
|
||||
raise ValueError(
|
||||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
|
||||
)
|
||||
if stride_length_s is None:
|
||||
stride_length_s = chunk_length_s / 6
|
||||
|
||||
@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
||||
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
||||
|
||||
if self.type not in {"ctc", "ctc_with_lm"}:
|
||||
raise ValueError(
|
||||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
|
||||
)
|
||||
if chunk_len < stride_left + stride_right:
|
||||
raise ValueError("Chunk length must be superior to stride length")
|
||||
|
||||
|
@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Non CTC models cannot use chunk_length
|
||||
with self.assertRaises(ValueError) as v:
|
||||
outputs = speech_recognizer(audio, chunk_length_s=10)
|
||||
self.assertEqual(v.exception, "")
|
||||
|
||||
# Non CTC models cannot use return_timestamps
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(ValueError) as v:
|
||||
outputs = speech_recognizer(audio, return_timestamps="char")
|
||||
self.assertEqual(v.exception, "")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||
with self.assertRaises(ValueError) as v:
|
||||
_ = speech_recognizer(waveform, chunk_length_s=10)
|
||||
self.assertEqual(
|
||||
str(v.exception),
|
||||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
|
||||
)
|
||||
|
||||
# Non CTC models cannot use return_timestamps
|
||||
with self.assertRaises(ValueError) as v:
|
||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||
self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !")
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_seq2seq(self):
|
||||
|
Loading…
Reference in New Issue
Block a user