mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
allow unused input parameters passthrough when chunking in asr pipelines (#33889)
* allow unused parameter passthrough when chunking in asr pipelines * format code * format * run fixup * update tests * update parameters to pipline in test * updates parametrs in tests * change spelling in gitignore * revert .gitignore to main * add git ignore of devcontainer folder * assert asr output follows expected inference output type * run fixup * Remove .devcontainer from .gitignore * remove compliance check
This commit is contained in:
parent
4dc1a69349
commit
a0f4f3174f
@ -434,7 +434,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
for item in chunk_iter(
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||
):
|
||||
yield item
|
||||
yield {**item, **extra}
|
||||
else:
|
||||
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
|
||||
processed = self.feature_extractor(
|
||||
|
@ -1443,6 +1443,25 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||
|
||||
@require_torch
|
||||
def test_input_parameter_passthrough(self):
|
||||
"""Test that chunked vs non chunked versions of ASR pipelines returns the same structure for the same inputs."""
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="hf-internal-testing/tiny-random-wav2vec2",
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
audio = ds[40]["audio"]["array"]
|
||||
|
||||
inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}
|
||||
|
||||
chunked_output = speech_recognizer(inputs.copy(), chunk_length_s=30)
|
||||
non_chunked_output = speech_recognizer(inputs.copy())
|
||||
assert (
|
||||
chunked_output.keys() == non_chunked_output.keys()
|
||||
), "The output structure should be the same for chunked vs non-chunked versions of asr pipelines."
|
||||
|
||||
@require_torch
|
||||
def test_return_timestamps_ctc_fast(self):
|
||||
speech_recognizer = pipeline(
|
||||
|
Loading…
Reference in New Issue
Block a user