mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Whisper] Use torch for stft if available (#26119)
* [Whisper] Use torch for stft if available * update docstring * mock patch decorator * fit on one line
This commit is contained in:
parent
7e93ce40c5
commit
814619f54f
@ -19,12 +19,16 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import is_torch_available
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
||||
"""
|
||||
Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
|
||||
"""
|
||||
waveform = torch.from_numpy(waveform).type(torch.float32)
|
||||
|
||||
window = torch.hann_window(self.n_fft)
|
||||
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
||||
mel_spec = mel_filters.T @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec.numpy()
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
||||
def zero_mean_unit_var_norm(
|
||||
@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several sequence(s).
|
||||
Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
|
||||
the STFT computation if available, otherwise a slower NumPy based one.
|
||||
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
# make sure list is in array format
|
||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||
|
||||
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
|
||||
extract_fbank_features = (
|
||||
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
|
||||
)
|
||||
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
|
||||
|
||||
if isinstance(input_features[0], List):
|
||||
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||
|
@ -23,16 +23,13 @@ import unittest
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
from transformers import WhisperFeatureExtractor
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import WhisperFeatureExtractor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
return values
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class WhisperFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
|
||||
return speech_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
|
||||
feature_extraction_class = WhisperFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
|
||||
@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
@require_torch
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
||||
@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_integration(self):
|
||||
@require_torch
|
||||
def test_torch_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_FEATURES = torch.tensor(
|
||||
[
|
||||
@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
|
||||
def test_numpy_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_FEATURES = np.array(
|
||||
[
|
||||
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
|
||||
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
|
||||
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
|
||||
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="np").input_features
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
audio = self._load_datasamples(1)[0]
|
||||
|
Loading…
Reference in New Issue
Block a user