Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)

* Fix loading pipelines with wav2vec models with lm when in local paths

* Adding tests

* Fix test

* Adding tests

* Flake8 fixes

* Removing conflict files :(

* Adding task type to test

* Remove unnecessary test and imports
This commit is contained in:
Javier de la Rosa 2022-02-15 13:45:08 +01:00 committed by GitHub
parent e1cbc073bf
commit 9eb7e9ba1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 6 deletions

View File

@ -127,7 +127,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
feature_extractor, tokenizer = super()._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else:
# BeamSearchDecoderCTC has no auto class

View File

@ -621,15 +621,20 @@ def pipeline(
import kenlm # to trigger `ImportError` if not installed
from pyctcdecode import BeamSearchDecoderCTC
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
allow_regex = [language_model_glob, alphabet_filename]
if os.path.isdir(model_name) or os.path.isfile(model_name):
decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
else:
language_model_glob = os.path.join(
BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
)
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
allow_regex = [language_model_glob, alphabet_filename]
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
kwargs["decoder"] = decoder
except ImportError as e:
logger.warning(
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
)
if task == "translation" and model.config.task_specific_params:

View File

@ -18,6 +18,7 @@ import numpy as np
import pytest
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@require_pyctcdecode
def test_with_local_lm_fast(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model=local_dir,
)
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]
n_repeats = 2
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_chunking(self):

View File

@ -31,6 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from huggingface_hub import snapshot_download
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
@ -303,3 +304,20 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
# are downloaded and none of the rest (e.g. README.md, ...)
self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
def test_decoder_local_files(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained(local_dir)
language_model = processor.decoder.model_container[processor.decoder._model_key]
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
local_decoder_files = os.listdir(local_dir)
expected_decoder_files = os.listdir(path_to_cached_dir)
local_decoder_files.sort()
expected_decoder_files.sort()
# test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files)