mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add numpy alternative to FE using torchaudio (#26339)
* add audio_utils usage in the FE of SpeechToText * clean unecessary parameters of AudioSpectrogramTransformer FE * add audio_utils usage in AST * add serialization tests and function to FEs * make style * remove use_torchaudio and move to_dict to FE * test audio_utils usage * make style and fix import (remove torchaudio dependency import) * fix torch dependency for jax and tensor tests * fix typo * clean tests with suggestions * add lines to test if is_speech_availble is False
This commit is contained in:
parent
e264745051
commit
be74b2ead6
@ -146,6 +146,7 @@ _import_structure = {
|
||||
"models.audio_spectrogram_transformer": [
|
||||
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"ASTConfig",
|
||||
"ASTFeatureExtractor",
|
||||
],
|
||||
"models.auto": [
|
||||
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
@ -535,6 +536,7 @@ _import_structure = {
|
||||
"models.speech_to_text": [
|
||||
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"Speech2TextConfig",
|
||||
"Speech2TextFeatureExtractor",
|
||||
"Speech2TextProcessor",
|
||||
],
|
||||
"models.speech_to_text_2": [
|
||||
@ -913,20 +915,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
|
||||
|
||||
# Speech-specific objects
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_speech_objects
|
||||
|
||||
_import_structure["utils.dummy_speech_objects"] = [
|
||||
name for name in dir(dummy_speech_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
|
||||
|
||||
# Tensorflow-text-specific objects
|
||||
try:
|
||||
if not is_tensorflow_text_available():
|
||||
@ -4352,6 +4340,7 @@ if TYPE_CHECKING:
|
||||
from .models.audio_spectrogram_transformer import (
|
||||
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ASTConfig,
|
||||
ASTFeatureExtractor,
|
||||
)
|
||||
from .models.auto import (
|
||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -4722,6 +4711,7 @@ if TYPE_CHECKING:
|
||||
from .models.speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
Speech2TextConfig,
|
||||
Speech2TextFeatureExtractor,
|
||||
Speech2TextProcessor,
|
||||
)
|
||||
from .models.speech_to_text_2 import (
|
||||
@ -5067,15 +5057,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
|
||||
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_speech_objects import *
|
||||
else:
|
||||
from .models.audio_spectrogram_transformer import ASTFeatureExtractor
|
||||
from .models.speech_to_text import Speech2TextFeatureExtractor
|
||||
|
||||
try:
|
||||
if not is_tensorflow_text_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
@ -584,14 +584,15 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
|
||||
Serializes this instance to a Python dictionary. Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["feature_extractor_type"] = self.__class__.__name__
|
||||
|
||||
if "mel_filters" in output:
|
||||
del output["mel_filters"]
|
||||
if "window" in output:
|
||||
del output["window"]
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
|
@ -13,14 +13,15 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_audio_spectrogram_transformer": [
|
||||
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"ASTConfig",
|
||||
]
|
||||
],
|
||||
"feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
|
||||
}
|
||||
|
||||
try:
|
||||
@ -36,19 +37,13 @@ else:
|
||||
"ASTPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_audio_spectrogram_transformer import (
|
||||
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ASTConfig,
|
||||
)
|
||||
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@ -63,14 +58,6 @@ if TYPE_CHECKING:
|
||||
ASTPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
@ -19,12 +19,18 @@ Feature extractor class for Audio Spectrogram Transformer.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import TensorType, logging
|
||||
from ...utils import TensorType, is_speech_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
||||
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
||||
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed
|
||||
length and normalizes them using a mean and standard deviation.
|
||||
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
|
||||
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 1):
|
||||
@ -83,6 +89,21 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.std = std
|
||||
self.return_attention_mask = return_attention_mask
|
||||
|
||||
if not is_speech_available():
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=256,
|
||||
num_mel_filters=self.num_mel_bins,
|
||||
min_frequency=20,
|
||||
max_frequency=sampling_rate // 2,
|
||||
sampling_rate=sampling_rate,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
|
||||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
||||
self.window = window_function(400, "hann", periodic=False)
|
||||
|
||||
def _extract_fbank_features(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
@ -93,17 +114,32 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
||||
and hence the waveform should not be normalized before feature extraction.
|
||||
"""
|
||||
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
fbank = ta_kaldi.fbank(
|
||||
waveform,
|
||||
htk_compat=True,
|
||||
sample_frequency=self.sampling_rate,
|
||||
use_energy=False,
|
||||
window_type="hanning",
|
||||
num_mel_bins=self.num_mel_bins,
|
||||
dither=0.0,
|
||||
frame_shift=10,
|
||||
)
|
||||
if is_speech_available():
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
fbank = ta_kaldi.fbank(
|
||||
waveform,
|
||||
sample_frequency=self.sampling_rate,
|
||||
window_type="hanning",
|
||||
num_mel_bins=self.num_mel_bins,
|
||||
)
|
||||
else:
|
||||
waveform = np.squeeze(waveform)
|
||||
fbank = spectrogram(
|
||||
waveform,
|
||||
self.window,
|
||||
frame_length=400,
|
||||
hop_length=160,
|
||||
fft_length=512,
|
||||
power=2.0,
|
||||
center=False,
|
||||
preemphasis=0.97,
|
||||
mel_filters=self.mel_filters,
|
||||
log_mel="log",
|
||||
mel_floor=1.192092955078125e-07,
|
||||
remove_dc_offset=True,
|
||||
).T
|
||||
|
||||
fbank = torch.from_numpy(fbank)
|
||||
|
||||
n_frames = fbank.shape[0]
|
||||
difference = max_length - n_frames
|
||||
|
@ -17,7 +17,6 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_sentencepiece_available,
|
||||
is_speech_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
@ -25,6 +24,7 @@ from ...utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
|
||||
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
|
||||
"processing_speech_to_text": ["Speech2TextProcessor"],
|
||||
}
|
||||
|
||||
@ -36,14 +36,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
|
||||
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@ -73,6 +65,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
|
||||
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
|
||||
from .processing_speech_to_text import Speech2TextProcessor
|
||||
|
||||
try:
|
||||
@ -83,14 +76,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .tokenization_speech_to_text import Speech2TextTokenizer
|
||||
|
||||
try:
|
||||
if not is_speech_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
@ -19,14 +19,17 @@ Feature extractor class for Speech2Text
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
from ...utils import PaddingStrategy, TensorType, is_speech_available, logging
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users
|
||||
should refer to this superclass for more information regarding those methods.
|
||||
|
||||
This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral
|
||||
mean and variance normalization to the extracted features.
|
||||
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
|
||||
otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 80):
|
||||
@ -77,6 +80,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.normalize_vars = normalize_vars
|
||||
self.return_attention_mask = True
|
||||
|
||||
if not is_speech_available():
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=256,
|
||||
num_mel_filters=self.num_mel_bins,
|
||||
min_frequency=20,
|
||||
max_frequency=sampling_rate // 2,
|
||||
sampling_rate=sampling_rate,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
|
||||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
||||
self.window = window_function(400, "povey", periodic=False)
|
||||
|
||||
def _extract_fbank_features(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
@ -86,9 +104,27 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
and hence the waveform should not be normalized before feature extraction.
|
||||
"""
|
||||
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
|
||||
return features.numpy()
|
||||
if is_speech_available():
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
|
||||
features = features.numpy()
|
||||
else:
|
||||
waveform = np.squeeze(waveform)
|
||||
features = spectrogram(
|
||||
waveform,
|
||||
self.window,
|
||||
frame_length=400,
|
||||
hop_length=160,
|
||||
fft_length=512,
|
||||
power=2.0,
|
||||
center=False,
|
||||
preemphasis=0.97,
|
||||
mel_filters=self.mel_filters,
|
||||
log_mel="log",
|
||||
mel_floor=1.192092955078125e-07,
|
||||
remove_dc_offset=True,
|
||||
).T
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def utterance_cmvn(
|
||||
|
@ -15,13 +15,15 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ASTFeatureExtractor
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
|
||||
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
|
||||
self.assertEquals(input_values.shape, (1, 1024, 128))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
|
||||
def test_feat_extract_from_and_save_pretrained(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertDictEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_to_json_file(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
|
||||
feat_extract_first.to_json_file(json_file_path)
|
||||
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
|
||||
# exact same tests than before, except that we simulate that torchaudio is not available
|
||||
@require_torch
|
||||
@unittest.mock.patch(
|
||||
"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
|
||||
lambda: False,
|
||||
)
|
||||
class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest):
|
||||
def test_using_audio_utils(self):
|
||||
# Tests that it uses audio_utils instead of torchaudio
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
self.assertTrue(hasattr(feat_extract, "window"))
|
||||
self.assertTrue(hasattr(feat_extract, "mel_filters"))
|
||||
|
||||
from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import (
|
||||
is_speech_available,
|
||||
)
|
||||
|
||||
self.assertFalse(is_speech_available())
|
||||
|
@ -15,20 +15,19 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None
|
||||
feature_extraction_class = Speech2TextFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
|
||||
@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEquals(input_features.shape, (1, 584, 24))
|
||||
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
|
||||
|
||||
def test_feat_extract_from_and_save_pretrained(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertDictEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_to_json_file(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
|
||||
feat_extract_first.to_json_file(json_file_path)
|
||||
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
|
||||
# exact same tests than before, except that we simulate that torchaudio is not available
|
||||
@require_torch
|
||||
@unittest.mock.patch(
|
||||
"transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False
|
||||
)
|
||||
class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest):
|
||||
def test_using_audio_utils(self):
|
||||
# Tests that it uses audio_utils instead of torchaudio
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
self.assertTrue(hasattr(feat_extract, "window"))
|
||||
self.assertTrue(hasattr(feat_extract, "mel_filters"))
|
||||
|
||||
from transformers.models.speech_to_text.feature_extraction_speech_to_text import is_speech_available
|
||||
|
||||
self.assertFalse(is_speech_available())
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import Speech2TextTokenizer, is_speech_available
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
|
||||
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
from .test_feature_extraction_speech_to_text import floats_list
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
|
||||
|
||||
|
||||
SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user