mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-29 17:22:25 +06:00
[Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683)
This commit is contained in:
parent
cdc51ffd27
commit
3a4376d008
@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
||||
else:
|
||||
# BeamSearchDecoderCTC has no auto class
|
||||
kwargs.pop("_from_auto", None)
|
||||
# snapshot_download has no `trust_remote_code` flag
|
||||
kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# make sure that only relevant filenames are downloaded
|
||||
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||
|
@ -22,6 +22,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
# test that the decoder files from the hub and the local files in the cache are the same
|
||||
self.assertListEqual(local_decoder_files, expected_decoder_files)
|
||||
|
||||
def test_processor_from_auto_processor(self):
    """Check that `AutoProcessor` and `Wav2Vec2ProcessorWithLM` agree for the same checkpoint.

    Both classes load "hf-internal-testing/processor_with_lm"; the test then
    compares the features each one produces for the same raw speech input and
    the text each one decodes from the same dummy logits.
    """
    # Load the same checkpoint once through each entry point.
    direct_processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
    auto_processor = AutoProcessor.from_pretrained("hf-internal-testing/processor_with_lm")

    speech_batch = floats_list((3, 1000))

    direct_features = direct_processor(speech_batch, return_tensors="np")
    auto_features = auto_processor(speech_batch, return_tensors="np")

    # Entries are compared through their sums, within a small tolerance.
    for name in direct_features.keys():
        direct_total = direct_features[name].sum()
        auto_total = auto_features[name].sum()
        self.assertAlmostEqual(direct_total, auto_total, delta=1e-2)

    dummy_logits = self._get_dummy_logits()

    direct_decoded = direct_processor.batch_decode(dummy_logits)
    auto_decoded = auto_processor.batch_decode(dummy_logits)

    # Decoded transcriptions must match exactly.
    self.assertListEqual(direct_decoded.text, auto_decoded.text)
|
||||
|
Loading…
Reference in New Issue
Block a user