mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[ASR pipeline] correct with lm pipeline (#15200)
* [ASR pipeline] correct with lm pipeline
* improve error
This commit is contained in:
parent
1144d336b6
commit
497346d07e
2
setup.py
2
setup.py
@ -152,7 +152,7 @@ _deps = [
|
|||||||
"tokenizers>=0.10.1,!=0.11.3",
|
"tokenizers>=0.10.1,!=0.11.3",
|
||||||
"torch>=1.0",
|
"torch>=1.0",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"pyctcdecode>=0.2.0",
|
"pyctcdecode>=0.3.0",
|
||||||
"tqdm>=4.27",
|
"tqdm>=4.27",
|
||||||
"unidic>=1.0.2",
|
"unidic>=1.0.2",
|
||||||
"unidic_lite>=1.0.7",
|
"unidic_lite>=1.0.7",
|
||||||
|
@ -62,7 +62,7 @@ deps = {
|
|||||||
"tokenizers": "tokenizers>=0.10.1,!=0.11.3",
|
"tokenizers": "tokenizers>=0.10.1,!=0.11.3",
|
||||||
"torch": "torch>=1.0",
|
"torch": "torch>=1.0",
|
||||||
"torchaudio": "torchaudio",
|
"torchaudio": "torchaudio",
|
||||||
"pyctcdecode": "pyctcdecode>=0.2.0",
|
"pyctcdecode": "pyctcdecode>=0.3.0",
|
||||||
"tqdm": "tqdm>=4.27",
|
"tqdm": "tqdm>=4.27",
|
||||||
"unidic": "unidic>=1.0.2",
|
"unidic": "unidic>=1.0.2",
|
||||||
"unidic_lite": "unidic_lite>=1.0.7",
|
"unidic_lite": "unidic_lite>=1.0.7",
|
||||||
|
@ -489,8 +489,9 @@ class FeatureExtractionMixin:
|
|||||||
|
|
||||||
# make sure private name "_processor_class" is correctly
|
# make sure private name "_processor_class" is correctly
|
||||||
# saved as "processor_class"
|
# saved as "processor_class"
|
||||||
if dictionary.get("_processor_class", None) is not None:
|
_processor_class = dictionary.pop("_processor_class", None)
|
||||||
dictionary["processor_class"] = dictionary.pop("_processor_class")
|
if _processor_class is not None:
|
||||||
|
dictionary["processor_class"] = _processor_class
|
||||||
|
|
||||||
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The HuggingFace Inc. team.
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
@ -617,17 +618,16 @@ def pipeline(
|
|||||||
and isinstance(model_name, str)
|
and isinstance(model_name, str)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
import kenlm # to trigger `ImportError` if not installed
|
||||||
from pyctcdecode import BeamSearchDecoderCTC
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||||
allow_regex = [language_model_glob, alphabet_filename]
|
allow_regex = [language_model_glob, alphabet_filename]
|
||||||
|
|
||||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
|
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
|
||||||
pretrained_model_name_or_path, allow_regex=allow_regex
|
|
||||||
)
|
|
||||||
kwargs["decoder"] = decoder
|
kwargs["decoder"] = decoder
|
||||||
except Exception as e:
|
except ImportError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||||
)
|
)
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@ -42,8 +43,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
|
|
||||||
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
|
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
|
||||||
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
|
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
|
||||||
|
|
||||||
config_dict.pop("feature_extractor_type")
|
config_dict.pop("feature_extractor_type")
|
||||||
config = Wav2Vec2FeatureExtractor(config_dict)
|
config = Wav2Vec2FeatureExtractor(**config_dict)
|
||||||
|
|
||||||
# save in new folder
|
# save in new folder
|
||||||
model_config.save_pretrained(tmpdirname)
|
model_config.save_pretrained(tmpdirname)
|
||||||
@ -51,6 +53,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
|
|
||||||
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
|
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
# make sure private variable is not incorrectly saved
|
||||||
|
dict_as_saved = json.loads(config.to_json_string())
|
||||||
|
self.assertTrue("_processor_class" not in dict_as_saved)
|
||||||
|
|
||||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|
||||||
def test_feature_extractor_from_local_file(self):
|
def test_feature_extractor_from_local_file(self):
|
||||||
|
@ -295,6 +295,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
self.assertEqual(output, [{"text": ANY(str)}])
|
self.assertEqual(output, [{"text": ANY(str)}])
|
||||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_pyctcdecode
|
||||||
|
def test_with_lm_fast(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="hf-internal-testing/processor_with_lm",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||||
|
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
audio = ds[40]["audio"]["array"]
|
||||||
|
|
||||||
|
n_repeats = 2
|
||||||
|
audio_tiled = np.tile(audio, n_repeats)
|
||||||
|
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||||
|
|
||||||
|
self.assertEqual(output, [{"text": ANY(str)}])
|
||||||
|
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_chunking(self):
|
def test_chunking(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user