allow unused input parameters passthrough when chunking in asr pipelines (#33889)

* allow unused parameter passthrough when chunking in asr pipelines

* format code

* format

* run fixup

* update tests

* update parameters to pipeline in test

* update parameters in tests

* change spelling in gitignore

* revert .gitignore to main

* add git ignore of devcontainer folder

* assert asr output follows expected inference output type

* run fixup

* Remove .devcontainer from .gitignore

* remove compliance check
This commit is contained in:
VictorAtIfInsurance 2024-11-25 11:36:44 +01:00 committed by GitHub
parent 4dc1a69349
commit a0f4f3174f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -434,7 +434,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            for item in chunk_iter(
                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
            ):
-               yield item
+               yield {**item, **extra}
        else:
            if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
                processed = self.feature_extractor(

View File

@ -1443,6 +1443,25 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
        self.assertEqual(output, [{"text": ANY(str)}])
        self.assertEqual(output[0]["text"][:6], "ZBT ZC")
@require_torch
def test_input_parameter_passthrough(self):
    """Chunked and non-chunked ASR pipeline calls must return the same output structure.

    The input dict carries an extra, unused key (``"id"``); the chunking code path
    must pass such keys through unchanged so both invocations yield dicts with
    identical keys.
    """
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="hf-internal-testing/tiny-random-wav2vec2",
    )
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
    audio = ds[40]["audio"]["array"]
    inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}

    # Copy the inputs for each call so neither invocation can mutate the other's dict.
    chunked_output = speech_recognizer(inputs.copy(), chunk_length_s=30)
    non_chunked_output = speech_recognizer(inputs.copy())

    # Use a unittest-style assertion (with message) for consistency with the
    # surrounding test class, which uses self.assertEqual elsewhere; a bare
    # `assert` would also be stripped under `python -O`.
    self.assertEqual(
        chunked_output.keys(),
        non_chunked_output.keys(),
        "The output structure should be the same for chunked vs non-chunked versions of asr pipelines.",
    )
    @require_torch
    def test_return_timestamps_ctc_fast(self):
        speech_recognizer = pipeline(