mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[Whisper] Add rescaling function with do_normalize
(#21263)
* add `zero_mean_unit_var_norm` function * normalize before MEL computation * fixup * add simple test * quality * Update tests/models/whisper/test_feature_extraction_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fixup * use attention masks if padding was applied * Update based on review Co-authored-by: bofeng huang <bofenghuang7@gmail.com> --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: bofeng huang <bofenghuang7@gmail.com>
This commit is contained in:
parent
b48c7f7b3f
commit
c87654dca1
@ -215,6 +215,29 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
return log_spec
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
||||
def zero_mean_unit_var_norm(
|
||||
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Every array in the list is normalized to have zero mean and unit variance
|
||||
"""
|
||||
if attention_mask is not None:
|
||||
attention_mask = np.array(attention_mask, np.int32)
|
||||
normed_input_values = []
|
||||
|
||||
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
||||
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
||||
if length < normed_slice.shape[0]:
|
||||
normed_slice[length:] = padding_value
|
||||
|
||||
normed_input_values.append(normed_slice)
|
||||
else:
|
||||
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
|
||||
|
||||
return normed_input_values
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
||||
@ -225,6 +248,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
padding: Optional[str] = "max_length",
|
||||
max_length: Optional[int] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -266,6 +290,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
pipeline.
|
||||
padding_value (`float`, defaults to 0.0):
|
||||
The value that is used to fill the padding values / vectors.
|
||||
do_normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||
improve the performance of the model.
|
||||
"""
|
||||
|
||||
if sampling_rate is not None:
|
||||
@ -312,6 +339,18 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
# make sure list is in array format
|
||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||
|
||||
if return_attention_mask:
|
||||
# rescale from sample (48000) to feature (3000)
|
||||
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if do_normalize:
|
||||
padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
|
||||
padded_inputs["input_features"],
|
||||
attention_mask=padded_inputs["attention_mask"],
|
||||
padding_value=self.padding_value,
|
||||
)
|
||||
|
||||
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
|
||||
|
||||
if isinstance(input_features[0], List):
|
||||
|
@ -21,6 +21,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
@ -198,8 +199,6 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
@ -222,3 +221,12 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
feaure_extractor = WhisperFeatureExtractor()
|
||||
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
audio = self._load_datasamples(1)[0]
|
||||
audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535 # Rescale to [0, 65535] to show issue
|
||||
audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0]
|
||||
|
||||
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
||||
|
Loading…
Reference in New Issue
Block a user