mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-17 19:48:23 +06:00

* Pipeline chunks. * Batching for Chunking pipelines ? * Batching for `question-answering` and `zero-shot-cls`. * Fixing for FNet. * Making ASR a chunk pipeline. * Chunking ASR API. * doc style. * Fixing ASR test. * Fixing QA error (p_mask, padding is 1, not 0). * Enable both vad and simple chunking. * Max length for vad. * remove inference mode, crashing on s2t. * Revert ChunkPipeline for ASRpipeline. Too many knobs for simple integration within the pipeline, better stick to external convenience functions instead, more control to be had, simpler pipeline and also easier to replace with other things later. * Drop necessity for PT for these. * Enabling generators. * Add mic + cleanup. * Typo. * Typo2. * Remove ASR work, it does not belong in this PR anymore. * Update src/transformers/pipelines/pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/pipelines/zero_shot_classification.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Adding many comments. * Doc quality. * `hidden_states` handling. * Adding doc. * Bad rebase. * Autofixing docs. * Fixing CRITICAL bug in the new Zerocls pipeline. Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
218 lines
8.2 KiB
Python
218 lines
8.2 KiB
Python
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from datasets import load_dataset
|
|
|
|
from transformers import (
|
|
MODEL_FOR_CTC_MAPPING,
|
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
|
AutoFeatureExtractor,
|
|
AutoTokenizer,
|
|
Speech2TextForConditionalGeneration,
|
|
Wav2Vec2ForCTC,
|
|
)
|
|
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
|
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
|
|
|
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
|
|
|
|
|
# We can't use this mixin because it assumes TF support.
|
|
# from .test_pipelines_common import CustomInputPipelineCommonMixin
|
|
|
|
|
|
@is_pipeline_test
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for `AutomaticSpeechRecognitionPipeline`.

    `model_mapping`, `get_test_pipeline` and `run_pipeline_test` are the hooks handed to
    `PipelineTestCaseMeta`; the remaining `test_*` methods are explicit checks against
    named checkpoints.
    """

    # Merge the speech seq2seq and CTC mappings into a single dict. Either mapping may be
    # falsy, in which case it contributes nothing. Both `.items()` results are wrapped in
    # `list` so the `+` concatenation works whether `.items()` returns a list or a view.
    model_mapping = {
        k: v
        for k, v in (list(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items()) if MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else [])
        + (list(MODEL_FOR_CTC_MAPPING.items()) if MODEL_FOR_CTC_MAPPING else [])
    }

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        """Build the pipeline under test together with two example inputs.

        Args:
            model: model instance passed to the pipeline.
            tokenizer: tokenizer passed to the pipeline; the test is skipped when it is `None`.
            feature_extractor: feature extractor passed to the pipeline.

        Returns:
            A `(speech_recognizer, examples)` tuple where `examples` is a list of two
            all-zero waveforms of 34000 and 14000 samples.
        """
        if tokenizer is None:
            # Side effect of no Fast Tokenizer class for these model, so skipping
            # But the slow tokenizer test should still run as they're quite small
            # `skipTest` raises `unittest.SkipTest`, so nothing after it would run.
            self.skipTest("No tokenizer available")

        speech_recognizer = AutomaticSpeechRecognitionPipeline(
            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
        )

        # test with a raw waveform
        audio = np.zeros((34000,))
        audio2 = np.zeros((14000,))
        return speech_recognizer, [audio, audio2]

    def run_pipeline_test(self, speech_recognizer, examples):
        """Check that the pipeline returns a `{"text": str}` dict for a raw waveform.

        `examples` is accepted as part of the hook signature but is not used here: a
        fresh all-zero waveform of 34000 samples is fed instead.
        """
        audio = np.zeros((34000,))
        outputs = speech_recognizer(audio)
        self.assertEqual(outputs, {"text": ANY(str)})

    @require_torch
    @slow
    def test_pt_defaults(self):
        """The task can be instantiated with the PyTorch framework and no explicit model."""
        pipeline("automatic-speech-recognition", framework="pt")

    @require_torch
    def test_small_model_pt(self):
        """A synthetic ramp waveform yields the expected text with the small s2t checkpoint."""
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/s2t-small-mustc-en-fr-st",
            tokenizer="facebook/s2t-small-mustc-en-fr-st",
            framework="pt",
        )
        # 34 repetitions of 0..999 as float32, i.e. 34000 samples.
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = speech_recognizer(waveform)
        self.assertEqual(output, {"text": "(Applaudissements)"})

    @require_tf
    def test_small_model_tf(self):
        """Always skipped: there is no TensorFlow test for this pipeline yet."""
        self.skipTest("Tensorflow not supported yet.")

    @require_torch
    def test_torch_small_no_tokenizer_files(self):
        """Creating the pipeline raises `OSError` for a model without tokenizer files."""
        # test that model without tokenizer file cannot be loaded
        with pytest.raises(OSError):
            pipeline(
                task="automatic-speech-recognition",
                model="patrickvonplaten/tiny-wav2vec2-no-tokenizer",
                framework="pt",
            )

    @require_torch
    @slow
    def test_torch_large(self):
        """wav2vec2-base-960h via `pipeline(...)`: synthetic waveform, then a dataset file path."""
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/wav2vec2-base-960h",
            tokenizer="facebook/wav2vec2-base-960h",
            framework="pt",
        )
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = speech_recognizer(waveform)
        self.assertEqual(output, {"text": ""})

        # Sorting by "id" keeps index 40 pointing at the same sample on every run.
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

    @require_torch
    @slow
    def test_torch_speech_encoder_decoder(self):
        """s2t-wav2vec2-large-en-de produces the expected German text for a dataset file."""
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/s2t-wav2vec2-large-en-de",
            feature_extractor="facebook/s2t-wav2vec2-large-en-de",
            framework="pt",
        )

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})

    @slow
    @require_torch
    def test_simple_wav2vec2(self):
        """Explicitly constructed wav2vec2 pipeline: waveform, file path and raw bytes inputs."""
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

        asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)

        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = asr(waveform)
        self.assertEqual(output, {"text": ""})

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = asr(filename)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

        # Same file again, this time passed as its raw byte content instead of its path.
        with open(filename, "rb") as f:
            data = f.read()
        output = asr(data)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

    @slow
    @require_torch
    @require_torchaudio
    def test_simple_s2t(self):
        """Explicitly constructed s2t (en-it) pipeline: waveform, file path and raw bytes inputs."""
        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-it-st")
        tokenizer = AutoTokenizer.from_pretrained("facebook/s2t-small-mustc-en-it-st")
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/s2t-small-mustc-en-it-st")

        asr = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)

        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)

        output = asr(waveform)
        self.assertEqual(output, {"text": "(Applausi)"})

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = asr(filename)
        self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

        # Same file again, this time passed as its raw byte content instead of its path.
        with open(filename, "rb") as f:
            data = f.read()
        output = asr(data)
        self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

    @slow
    @require_torch
    @require_torchaudio
    def test_xls_r_to_en(self):
        """wav2vec2-xls-r-1b-21-to-en produces the expected English text for a dataset file."""
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/wav2vec2-xls-r-1b-21-to-en",
            feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en",
            framework="pt",
        )

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})

    @slow
    @require_torch
    @require_torchaudio
    def test_xls_r_from_en(self):
        """wav2vec2-xls-r-1b-en-to-15 produces the expected German text for a dataset file."""
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/wav2vec2-xls-r-1b-en-to-15",
            feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15",
            framework="pt",
        )

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
        filename = ds[40]["file"]
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|