Add numpy alternative to FE using torchaudio (#26339)

* add audio_utils usage in the FE of SpeechToText

* clean unecessary parameters of AudioSpectrogramTransformer FE

* add audio_utils usage in AST

* add serialization tests and function to FEs

* make style

* remove use_torchaudio and move to_dict to FE

* test audio_utils usage

* make style and fix import (remove torchaudio dependency import)

* fix torch dependency for jax and tensor tests

* fix typo

* clean tests with suggestions

* add lines to test if is_speech_availble is False
This commit is contained in:
Yoach Lacombe 2023-11-08 07:39:37 +00:00 committed by GitHub
parent e264745051
commit be74b2ead6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 208 additions and 98 deletions

View File

@ -146,6 +146,7 @@ _import_structure = {
"models.audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
"ASTFeatureExtractor",
],
"models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
@ -535,6 +536,7 @@ _import_structure = {
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
"Speech2TextProcessor",
],
"models.speech_to_text_2": [
@ -913,20 +915,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
# Speech-specific objects
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_speech_objects
_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
else:
_import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
# Tensorflow-text-specific objects
try:
if not is_tensorflow_text_available():
@ -4352,6 +4340,7 @@ if TYPE_CHECKING:
from .models.audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig,
ASTFeatureExtractor,
)
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -4722,6 +4711,7 @@ if TYPE_CHECKING:
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
Speech2TextFeatureExtractor,
Speech2TextProcessor,
)
from .models.speech_to_text_2 import (
@ -5067,15 +5057,6 @@ if TYPE_CHECKING:
else:
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
else:
from .models.audio_spectrogram_transformer import ASTFeatureExtractor
from .models.speech_to_text import Speech2TextFeatureExtractor
try:
if not is_tensorflow_text_available():
raise OptionalDependencyNotAvailable()

View File

@ -584,14 +584,15 @@ class FeatureExtractionMixin(PushToHubMixin):
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
if "mel_filters" in output:
del output["mel_filters"]
if "window" in output:
del output["window"]
return output
@classmethod

View File

@ -13,14 +13,15 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
]
],
"feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
}
try:
@ -36,19 +37,13 @@ else:
"ASTPreTrainedModel",
]
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]
if TYPE_CHECKING:
from .configuration_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig,
)
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
try:
if not is_torch_available():
@ -63,14 +58,6 @@ if TYPE_CHECKING:
ASTPreTrainedModel,
)
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
else:
import sys

View File

@ -19,12 +19,18 @@ Feature extractor class for Audio Spectrogram Transformer.
from typing import List, Optional, Union
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
from ...utils import TensorType, is_speech_available, is_torch_available, logging
if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed
length and normalizes them using a mean and standard deviation.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
Args:
feature_size (`int`, *optional*, defaults to 1):
@ -83,6 +89,21 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
self.std = std
self.return_attention_mask = return_attention_mask
if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "hann", periodic=False)
def _extract_fbank_features(
self,
waveform: np.ndarray,
@ -93,17 +114,32 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
and hence the waveform should not be normalized before feature extraction.
"""
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=self.sampling_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
dither=0.0,
frame_shift=10,
)
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank(
waveform,
sample_frequency=self.sampling_rate,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
)
else:
waveform = np.squeeze(waveform)
fbank = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
fbank = torch.from_numpy(fbank)
n_frames = fbank.shape[0]
difference = max_length - n_frames

View File

@ -17,7 +17,6 @@ from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_speech_available,
is_tf_available,
is_torch_available,
)
@ -25,6 +24,7 @@ from ...utils import (
_import_structure = {
"configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
"processing_speech_to_text": ["Speech2TextProcessor"],
}
@ -36,14 +36,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
@ -73,6 +65,7 @@ else:
if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
from .processing_speech_to_text import Speech2TextProcessor
try:
@ -83,14 +76,6 @@ if TYPE_CHECKING:
else:
from .tokenization_speech_to_text import Speech2TextTokenizer
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()

View File

@ -19,14 +19,17 @@ Feature extractor class for Speech2Text
from typing import List, Optional, Union
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
from ...utils import PaddingStrategy, TensorType, is_speech_available, logging
if is_speech_available():
import torch
import torchaudio.compliance.kaldi as ta_kaldi
logger = logging.get_logger(__name__)
@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral
mean and variance normalization to the extracted features.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.
Args:
feature_size (`int`, *optional*, defaults to 80):
@ -77,6 +80,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
self.normalize_vars = normalize_vars
self.return_attention_mask = True
if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "povey", periodic=False)
def _extract_fbank_features(
self,
waveform: np.ndarray,
@ -86,9 +104,27 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
and hence the waveform should not be normalized before feature extraction.
"""
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
return features.numpy()
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
features = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
return features
@staticmethod
def utterance_cmvn(

View File

@ -15,13 +15,15 @@
import itertools
import os
import random
import tempfile
import unittest
import numpy as np
from transformers import ASTFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 1024, 128))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertDictEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertEqual(dict_first, dict_second)
# exact same tests than before, except that we simulate that torchaudio is not available
@require_torch
@unittest.mock.patch(
"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
lambda: False,
)
class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest):
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import (
is_speech_available,
)
self.assertFalse(is_speech_available())

View File

@ -15,20 +15,19 @@
import itertools
import os
import random
import tempfile
import unittest
import numpy as np
from transformers import is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
from transformers import Speech2TextFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import Speech2TextFeatureExtractor
global_rng = random.Random()
@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
@require_torch
@require_torchaudio
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None
feature_extraction_class = Speech2TextFeatureExtractor
def setUp(self):
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEquals(input_features.shape, (1, 584, 24))
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertDictEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertEqual(dict_first, dict_second)
# exact same tests than before, except that we simulate that torchaudio is not available
@require_torch
@unittest.mock.patch(
"transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False
)
class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest):
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
from transformers.models.speech_to_text.feature_extraction_speech_to_text import is_speech_available
self.assertFalse(is_speech_available())

View File

@ -18,7 +18,7 @@ import unittest
from pathlib import Path
from shutil import copyfile
from transformers import Speech2TextTokenizer, is_speech_available
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio
from transformers.utils import FEATURE_EXTRACTOR_NAME
@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME
from .test_feature_extraction_speech_to_text import floats_list
if is_speech_available():
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")