transformers/tests/test_pipelines_automatic_speech_recognition.py
Nicolas Patry b058490ceb
ChunkPipeline (batch_size enabled on zero-cls and qa pipelines. (#14225)
* Pipeline chunks.

* Batching for Chunking pipelines ?

* Batching for `question-answering` and `zero-shot-cls`.

* Fixing for FNet.

* Making ASR a chunk pipeline.

* Chunking ASR API.

* doc style.

* Fixing ASR test.

* Fixing QA eror (p_mask, padding is 1, not 0).

* Enable both vad and simple chunking.

* Max length for vad.

* remove inference mode, crashing on s2t.

* Revert ChunkPipeline for ASRpipeline.

Too many knobs for simple integration within the pipeline, better stick
to external convenience functions instead, more control to be had,
simpler pipeline and also easier to replace with other things later.

* Drop necessity for PT for these.

* Enabling generators.

* Add mic + cleanup.

* Typo.

* Typo2.

* Remove ASR work, it does not belong in this PR anymore.

* Update src/transformers/pipelines/pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/pipelines/zero_shot_classification.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Adding many comments.

* Doc quality.

* `hidden_states` handling.

* Adding doc.

* Bad rebase.

* Autofixing docs.

* Fixing CRITICAL bug in the new Zerocls pipeline.

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2021-12-27 11:26:20 +01:00

218 lines
8.2 KiB
Python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
from datasets import load_dataset
from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
from .test_pipelines_common import ANY, PipelineTestCaseMeta
# We can't use this mixin because it assumes TF support.
# from .test_pipelines_common import CustomInputPipelineCommonMixin
@is_pipeline_test
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = {
k: v
for k, v in (list(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items()) if MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else [])
+ (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
}
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if tokenizer is None:
# Side effect of no Fast Tokenizer class for these model, so skipping
# But the slow tokenizer test should still run as they're quite small
self.skipTest("No tokenizer available")
return
# return None, None
speech_recognizer = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
# test with a raw waveform
audio = np.zeros((34000,))
audio2 = np.zeros((14000,))
return speech_recognizer, [audio, audio2]
def run_pipeline_test(self, speech_recognizer, examples):
audio = np.zeros((34000,))
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
@require_torch
@slow
def test_pt_defaults(self):
pipeline("automatic-speech-recognition", framework="pt")
@require_torch
def test_small_model_pt(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
@require_tf
def test_small_model_tf(self):
self.skipTest("Tensorflow not supported yet.")
@require_torch
def test_torch_small_no_tokenizer_files(self):
# test that model without tokenizer file cannot be loaded
with pytest.raises(OSError):
pipeline(
task="automatic-speech-recognition",
model="patrickvonplaten/tiny-wav2vec2-no-tokenizer",
framework="pt",
)
@require_torch
@slow
def test_torch_large(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-base-960h",
tokenizer="facebook/wav2vec2-base-960h",
framework="pt",
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": ""})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@require_torch
@slow
def test_torch_speech_encoder_decoder(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-wav2vec2-large-en-de",
feature_extractor="facebook/s2t-wav2vec2-large-en-de",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})
@slow
@require_torch
def test_simple_wav2vec2(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = asr(waveform)
self.assertEqual(output, {"text": ""})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
output = asr(data)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@slow
@require_torch
@require_torchaudio
def test_simple_s2t(self):
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-it-st")
tokenizer = AutoTokenizer.from_pretrained("facebook/s2t-small-mustc-en-it-st")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/s2t-small-mustc-en-it-st")
asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = asr(waveform)
self.assertEqual(output, {"text": "(Applausi)"})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
@slow
@require_torch
@require_torchaudio
def test_xls_r_to_en(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-xls-r-1b-21-to-en",
feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})
@slow
@require_torch
@require_torchaudio
def test_xls_r_from_en(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/wav2vec2-xls-r-1b-en-to-15",
feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})