Improve error messaging for ASR pipeline. (#19570)

* Improve error messaging for ASR pipeline.

- Raise error early (in `_sanitize`) so users don't waste time trying to
  run queries with invalid params.

- Fix the error was after using `config.inputs_to_logits_ratio` so our
  check was masked by the failing property does not exist.

- Added some manual check on s2t for the error message.
  No non ctc model seems to be used by the default runner (they are all
  skipped).

* Removing pdb.

* Stop the early error it doesn't really work :(.
This commit is contained in:
Nicolas Patry 2022-10-14 17:12:21 +02:00 committed by GitHub
parent 5ef2186692
commit 463226e2ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 5 deletions

View File

@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
if chunk_length_s:
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if stride_length_s is None:
stride_length_s = chunk_length_s / 6
@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length")

View File

@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
},
)
else:
# Non CTC models cannot use chunk_length
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, chunk_length_s=10)
self.assertEqual(v.exception, "")
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError):
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(v.exception, "")
@require_torch
@slow
@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(
str(v.exception),
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
)
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, return_timestamps="char")
self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !")
@require_torch
def test_small_model_pt_seq2seq(self):