From be74b2ead69df1849ec62ac5c86c7d5dee663448 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Wed, 8 Nov 2023 07:39:37 +0000 Subject: [PATCH] Add numpy alternative to FE using torchaudio (#26339) * add audio_utils usage in the FE of SpeechToText * clean unecessary parameters of AudioSpectrogramTransformer FE * add audio_utils usage in AST * add serialization tests and function to FEs * make style * remove use_torchaudio and move to_dict to FE * test audio_utils usage * make style and fix import (remove torchaudio dependency import) * fix torch dependency for jax and tensor tests * fix typo * clean tests with suggestions * add lines to test if is_speech_availble is False --- src/transformers/__init__.py | 27 ++------ src/transformers/feature_extraction_utils.py | 11 +-- .../audio_spectrogram_transformer/__init__.py | 21 ++---- ...xtraction_audio_spectrogram_transformer.py | 68 ++++++++++++++----- .../models/speech_to_text/__init__.py | 19 +----- .../feature_extraction_speech_to_text.py | 52 +++++++++++--- ...xtraction_audio_spectrogram_transformer.py | 49 ++++++++++++- .../test_feature_extraction_speech_to_text.py | 53 +++++++++++++-- .../test_processor_speech_to_text.py | 6 +- 9 files changed, 208 insertions(+), 98 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 97cc4e578c7..4e98a717f02 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -146,6 +146,7 @@ _import_structure = { "models.audio_spectrogram_transformer": [ "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ASTConfig", + "ASTFeatureExtractor", ], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -535,6 +536,7 @@ _import_structure = { "models.speech_to_text": [ "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig", + "Speech2TextFeatureExtractor", "Speech2TextProcessor", ], "models.speech_to_text_2": [ @@ -913,20 +915,6 @@ except OptionalDependencyNotAvailable: else: _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"] -# Speech-specific objects -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_speech_objects - - _import_structure["utils.dummy_speech_objects"] = [ - name for name in dir(dummy_speech_objects) if not name.startswith("_") - ] -else: - _import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor") - _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor") - # Tensorflow-text-specific objects try: if not is_tensorflow_text_available(): @@ -4352,6 +4340,7 @@ if TYPE_CHECKING: from .models.audio_spectrogram_transformer import ( AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ASTConfig, + ASTFeatureExtractor, ) from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -4722,6 +4711,7 @@ if TYPE_CHECKING: from .models.speech_to_text import ( SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig, + Speech2TextFeatureExtractor, Speech2TextProcessor, ) from .models.speech_to_text_2 import ( @@ -5067,15 +5057,6 @@ if TYPE_CHECKING: else: from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_speech_objects import * - else: - from .models.audio_spectrogram_transformer import ASTFeatureExtractor - from .models.speech_to_text import Speech2TextFeatureExtractor - try: if not is_tensorflow_text_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index b626ff3dd71..fe1f7a78c93 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -584,14 +584,15 @@ class FeatureExtractionMixin(PushToHubMixin): def to_dict(self) -> Dict[str, Any]: """ - Serializes this instance to a Python dictionary. - - Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance. + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ output = copy.deepcopy(self.__dict__) output["feature_extractor_type"] = self.__class__.__name__ - + if "mel_filters" in output: + del output["mel_filters"] + if "window" in output: + del output["window"] return output @classmethod diff --git a/src/transformers/models/audio_spectrogram_transformer/__init__.py b/src/transformers/models/audio_spectrogram_transformer/__init__.py index 9aa42423cf5..2b48fe07311 100644 --- a/src/transformers/models/audio_spectrogram_transformer/__init__.py +++ b/src/transformers/models/audio_spectrogram_transformer/__init__.py @@ -13,14 +13,15 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available _import_structure = { "configuration_audio_spectrogram_transformer": [ "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ASTConfig", - ] + ], + "feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"], } try: @@ -36,19 +37,13 @@ else: "ASTPreTrainedModel", ] -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"] if TYPE_CHECKING: from .configuration_audio_spectrogram_transformer import ( AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ASTConfig, ) + from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor try: if not is_torch_available(): @@ -63,14 +58,6 @@ if TYPE_CHECKING: ASTPreTrainedModel, ) - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor - else: import sys diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py index 786548fd233..2bd122b4098 100644 --- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py @@ -19,12 +19,18 @@ Feature extractor class for Audio Spectrogram Transformer. from typing import List, Optional, Union import numpy as np -import torch -import torchaudio.compliance.kaldi as ta_kaldi +from ...audio_utils import mel_filter_bank, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging +from ...utils import TensorType, is_speech_available, is_torch_available, logging + + +if is_speech_available(): + import torchaudio.compliance.kaldi as ta_kaldi + +if is_torch_available(): + import torch logger = logging.get_logger(__name__) @@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. - This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed - length and normalizes them using a mean and standard deviation. + This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy + otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation. Args: feature_size (`int`, *optional*, defaults to 1): @@ -83,6 +89,21 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): self.std = std self.return_attention_mask = return_attention_mask + if not is_speech_available(): + mel_filters = mel_filter_bank( + num_frequency_bins=256, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.window = window_function(400, "hann", periodic=False) + def _extract_fbank_features( self, waveform: np.ndarray, @@ -93,17 +114,32 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): and hence the waveform should not be normalized before feature extraction. """ # waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers - waveform = torch.from_numpy(waveform).unsqueeze(0) - fbank = ta_kaldi.fbank( - waveform, - htk_compat=True, - sample_frequency=self.sampling_rate, - use_energy=False, - window_type="hanning", - num_mel_bins=self.num_mel_bins, - dither=0.0, - frame_shift=10, - ) + if is_speech_available(): + waveform = torch.from_numpy(waveform).unsqueeze(0) + fbank = ta_kaldi.fbank( + waveform, + sample_frequency=self.sampling_rate, + window_type="hanning", + num_mel_bins=self.num_mel_bins, + ) + else: + waveform = np.squeeze(waveform) + fbank = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + + fbank = torch.from_numpy(fbank) n_frames = fbank.shape[0] difference = max_length - n_frames diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py index 45a91c2b496..3194f99931a 100644 --- a/src/transformers/models/speech_to_text/__init__.py +++ b/src/transformers/models/speech_to_text/__init__.py @@ -17,7 +17,6 @@ from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, - is_speech_available, is_tf_available, is_torch_available, ) @@ -25,6 +24,7 @@ from ...utils import ( _import_structure = { "configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"], + "feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"], "processing_speech_to_text": ["Speech2TextProcessor"], } @@ -36,14 +36,6 @@ except OptionalDependencyNotAvailable: else: _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"] -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"] - try: if not is_tf_available(): raise OptionalDependencyNotAvailable() @@ -73,6 +65,7 @@ else: if TYPE_CHECKING: from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig + from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor from .processing_speech_to_text import Speech2TextProcessor try: @@ -83,14 +76,6 @@ if TYPE_CHECKING: else: from .tokenization_speech_to_text import Speech2TextTokenizer - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor - try: if not is_tf_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py index 0d5b077c938..193f2dda094 100644 --- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -19,14 +19,17 @@ Feature extractor class for Speech2Text from typing import List, Optional, Union import numpy as np -import torch -import torchaudio.compliance.kaldi as ta_kaldi +from ...audio_utils import mel_filter_bank, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging +from ...utils import PaddingStrategy, TensorType, is_speech_available, logging +if is_speech_available(): + import torch + import torchaudio.compliance.kaldi as ta_kaldi + logger = logging.get_logger(__name__) @@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. - This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral - mean and variance normalization to the extracted features. + This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy + otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features. Args: feature_size (`int`, *optional*, defaults to 80): @@ -77,6 +80,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): self.normalize_vars = normalize_vars self.return_attention_mask = True + if not is_speech_available(): + mel_filters = mel_filter_bank( + num_frequency_bins=256, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.window = window_function(400, "povey", periodic=False) + def _extract_fbank_features( self, waveform: np.ndarray, @@ -86,9 +104,27 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): and hence the waveform should not be normalized before feature extraction. """ waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers - waveform = torch.from_numpy(waveform).unsqueeze(0) - features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate) - return features.numpy() + if is_speech_available(): + waveform = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate) + features = features.numpy() + else: + waveform = np.squeeze(waveform) + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features @staticmethod def utterance_cmvn( diff --git a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py index 85d696479c2..ac6cd5eb1fb 100644 --- a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py +++ b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py @@ -15,13 +15,15 @@ import itertools +import os import random +import tempfile import unittest import numpy as np from transformers import ASTFeatureExtractor -from transformers.testing_utils import require_torch, require_torchaudio +from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio from transformers.utils.import_utils import is_torch_available from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin @@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test input_values = feature_extractor(input_speech, return_tensors="pt").input_values self.assertEquals(input_values.shape, (1, 1024, 128)) self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4)) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertDictEqual(dict_first, dict_second) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertEqual(dict_first, dict_second) + + +# exact same tests than before, except that we simulate that torchaudio is not available +@require_torch +@unittest.mock.patch( + "transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available", + lambda: False, +) +class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest): + def test_using_audio_utils(self): + # Tests that it uses audio_utils instead of torchaudio + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + + self.assertTrue(hasattr(feat_extract, "window")) + self.assertTrue(hasattr(feat_extract, "mel_filters")) + + from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import ( + is_speech_available, + ) + + self.assertFalse(is_speech_available()) diff --git a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py index d8929c4ef0d..f652d09ffca 100644 --- a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py +++ b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py @@ -15,20 +15,19 @@ import itertools +import os import random +import tempfile import unittest import numpy as np -from transformers import is_speech_available -from transformers.testing_utils import require_torch, require_torchaudio +from transformers import Speech2TextFeatureExtractor +from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin -if is_speech_available(): - from transformers import Speech2TextFeatureExtractor - global_rng = random.Random() @@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase): @require_torch @require_torchaudio class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None + feature_extraction_class = Speech2TextFeatureExtractor def setUp(self): self.feat_extract_tester = Speech2TextFeatureExtractionTester(self) @@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt input_features = feature_extractor(input_speech, return_tensors="pt").input_features self.assertEquals(input_features.shape, (1, 584, 24)) self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4)) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertDictEqual(dict_first, dict_second) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertEqual(dict_first, dict_second) + + +# exact same tests than before, except that we simulate that torchaudio is not available +@require_torch +@unittest.mock.patch( + "transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False +) +class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest): + def test_using_audio_utils(self): + # Tests that it uses audio_utils instead of torchaudio + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + + self.assertTrue(hasattr(feat_extract, "window")) + self.assertTrue(hasattr(feat_extract, "mel_filters")) + + from transformers.models.speech_to_text.feature_extraction_speech_to_text import is_speech_available + + self.assertFalse(is_speech_available()) diff --git a/tests/models/speech_to_text/test_processor_speech_to_text.py b/tests/models/speech_to_text/test_processor_speech_to_text.py index 9b8b3ccf66b..923ba29d1a8 100644 --- a/tests/models/speech_to_text/test_processor_speech_to_text.py +++ b/tests/models/speech_to_text/test_processor_speech_to_text.py @@ -18,7 +18,7 @@ import unittest from pathlib import Path from shutil import copyfile -from transformers import Speech2TextTokenizer, is_speech_available +from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio from transformers.utils import FEATURE_EXTRACTOR_NAME @@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME from .test_feature_extraction_speech_to_text import floats_list -if is_speech_available(): - from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor - - SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")