[Whisper] Use torch for stft if available (#26119)

* [Whisper] Use torch for stft if available

* update docstring

* mock patch decorator

* fit on one line
This commit is contained in:
Sanchit Gandhi 2023-12-21 11:04:05 +00:00 committed by GitHub
parent 7e93ce40c5
commit 814619f54f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 13 deletions

View File

@ -19,12 +19,16 @@ from typing import List, Optional, Union
import numpy as np
from ... import is_torch_available
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
log_spec = (log_spec + 4.0) / 4.0
return log_spec
def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
"""
Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
"""
waveform = torch.from_numpy(waveform).type(torch.float32)
window = torch.hann_window(self.n_fft)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec.numpy()
@staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
**kwargs,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
the STFT computation if available, otherwise a slower NumPy based one.
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
)
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

View File

@ -23,16 +23,13 @@ import unittest
import numpy as np
from datasets import load_dataset
from transformers import is_speech_available
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import WhisperFeatureExtractor
if is_torch_available():
import torch
@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
return speech_inputs
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
feature_extraction_class = WhisperFeatureExtractor
def setUp(self):
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
@require_torch
def test_double_precision_pad(self):
import torch
@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
return [x["array"] for x in speech_samples]
def test_integration(self):
@require_torch
def test_torch_integration(self):
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
def test_numpy_integration(self):
# fmt: off
EXPECTED_INPUT_FEATURES = np.array(
[
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
]
)
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="np").input_features
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
audio = self._load_datasamples(1)[0]