mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Wav2Vec2ProcessorWithLM] improve decoder downlaod (#15040)
This commit is contained in:
parent
6ea6266625
commit
efb35a4107
@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM:
|
||||
# BeamSearchDecoderCTC has no auto class
|
||||
kwargs.pop("_from_auto", None)
|
||||
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
|
||||
# make sure that only relevant filenames are downloaded
|
||||
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||
allow_regex = [language_model_filenames, alphabet_filename]
|
||||
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
|
||||
pretrained_model_name_or_path, allow_regex=allow_regex, **kwargs
|
||||
)
|
||||
|
||||
# set language model attributes
|
||||
for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]:
|
||||
|
@ -18,6 +18,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(decoded_decoder, decoded_processor)
|
||||
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
|
||||
|
||||
def test_decoder_download_ignores_files(self):
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
|
||||
|
||||
language_model = processor.decoder.model_container[processor.decoder._model_key]
|
||||
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
|
||||
|
||||
downloaded_decoder_files = os.listdir(path_to_cached_dir)
|
||||
|
||||
# test that only decoder relevant files from
|
||||
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
|
||||
# are downloaded and none of the rest (e.g. README.md, ...)
|
||||
self.assertListEqual(downloaded_decoder_files, ["alphabet.json", "language_model"])
|
||||
|
Loading…
Reference in New Issue
Block a user