mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
allow unused input parameters passthrough when chunking in asr pipelines (#33889)
* allow unused parameter passthrough when chunking in asr pipelines * format code * format * run fixup * update tests * update parameters to pipline in test * updates parametrs in tests * change spelling in gitignore * revert .gitignore to main * add git ignore of devcontainer folder * assert asr output follows expected inference output type * run fixup * Remove .devcontainer from .gitignore * remove compliance check
This commit is contained in:
parent
4dc1a69349
commit
a0f4f3174f
@ -434,7 +434,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
for item in chunk_iter(
|
for item in chunk_iter(
|
||||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||||
):
|
):
|
||||||
yield item
|
yield {**item, **extra}
|
||||||
else:
|
else:
|
||||||
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
|
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
|
||||||
processed = self.feature_extractor(
|
processed = self.feature_extractor(
|
||||||
|
@ -1443,6 +1443,25 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(output, [{"text": ANY(str)}])
|
self.assertEqual(output, [{"text": ANY(str)}])
|
||||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_input_parameter_passthrough(self):
|
||||||
|
"""Test that chunked vs non chunked versions of ASR pipelines returns the same structure for the same inputs."""
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="hf-internal-testing/tiny-random-wav2vec2",
|
||||||
|
)
|
||||||
|
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
audio = ds[40]["audio"]["array"]
|
||||||
|
|
||||||
|
inputs = {"raw": audio, "sampling_rate": 16_000, "id": 1}
|
||||||
|
|
||||||
|
chunked_output = speech_recognizer(inputs.copy(), chunk_length_s=30)
|
||||||
|
non_chunked_output = speech_recognizer(inputs.copy())
|
||||||
|
assert (
|
||||||
|
chunked_output.keys() == non_chunked_output.keys()
|
||||||
|
), "The output structure should be the same for chunked vs non-chunked versions of asr pipelines."
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_return_timestamps_ctc_fast(self):
|
def test_return_timestamps_ctc_fast(self):
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
|
Loading…
Reference in New Issue
Block a user