mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)
* Fix loading pipelines with wav2vec models with lm when in local paths
* Adding tests
* Fix test
* Adding tests
* Flake8 fixes
* Removing conflict files :(
* Adding task type to test
* Remove unnecessary test and imports
This commit is contained in:
parent
e1cbc073bf
commit
9eb7e9ba1d
@ -127,7 +127,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
|
||||
feature_extractor, tokenizer = super()._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path):
|
||||
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
|
||||
else:
|
||||
# BeamSearchDecoderCTC has no auto class
|
||||
|
@ -621,15 +621,20 @@ def pipeline(
|
||||
import kenlm # to trigger `ImportError` if not installed
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||
allow_regex = [language_model_glob, alphabet_filename]
|
||||
if os.path.isdir(model_name) or os.path.isfile(model_name):
|
||||
decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
|
||||
else:
|
||||
language_model_glob = os.path.join(
|
||||
BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
|
||||
)
|
||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||
allow_regex = [language_model_glob, alphabet_filename]
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
|
||||
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
|
||||
kwargs["decoder"] = decoder
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||
f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||
)
|
||||
|
||||
if task == "translation" and model.config.task_specific_params:
|
||||
|
@ -18,6 +18,7 @@ import numpy as np
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_with_local_lm_fast(self):
|
||||
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model=local_dir,
|
||||
)
|
||||
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
audio = ds[40]["audio"]["array"]
|
||||
|
||||
n_repeats = 2
|
||||
audio_tiled = np.tile(audio, n_repeats)
|
||||
|
||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking(self):
|
||||
|
@ -31,6 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from huggingface_hub import snapshot_download
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
|
||||
@ -303,3 +304,20 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
|
||||
# are downloaded and none of the rest (e.g. README.md, ...)
|
||||
self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
|
||||
|
||||
def test_decoder_local_files(self):
|
||||
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
|
||||
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained(local_dir)
|
||||
|
||||
language_model = processor.decoder.model_container[processor.decoder._model_key]
|
||||
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
|
||||
|
||||
local_decoder_files = os.listdir(local_dir)
|
||||
expected_decoder_files = os.listdir(path_to_cached_dir)
|
||||
|
||||
local_decoder_files.sort()
|
||||
expected_decoder_files.sort()
|
||||
|
||||
# test that both decoder from hub and local files in cache are the same
|
||||
self.assertListEqual(local_decoder_files, expected_decoder_files)
|
||||
|
Loading…
Reference in New Issue
Block a user