[Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683)

This commit is contained in:
Patrick von Platen 2022-02-16 17:33:33 +01:00 committed by GitHub
parent cdc51ffd27
commit 3a4376d008
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 0 deletions

View File

@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
else:
# BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None)
# snapshot_download has no `trust_remote_code` flag
kwargs.pop("trust_remote_code", None)
# make sure that only relevant filenames are downloaded
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")

View File

@ -22,6 +22,7 @@ from pathlib import Path
import numpy as np
from transformers import AutoProcessor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files)
def test_processor_from_auto_processor(self):
processor_wav2vec2 = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
processor_auto = AutoProcessor.from_pretrained("hf-internal-testing/processor_with_lm")
raw_speech = floats_list((3, 1000))
input_wav2vec2 = processor_wav2vec2(raw_speech, return_tensors="np")
input_auto = processor_auto(raw_speech, return_tensors="np")
for key in input_wav2vec2.keys():
self.assertAlmostEqual(input_wav2vec2[key].sum(), input_auto[key].sum(), delta=1e-2)
logits = self._get_dummy_logits()
decoded_wav2vec2 = processor_wav2vec2.batch_decode(logits)
decoded_auto = processor_auto.batch_decode(logits)
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)