mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
audio_utils improvements (#21998)
* silly change to allow making a PR * clean up doc comments * simplify hertz_to_mel and mel_to_hertz * fixup * clean up power_to_db * also add amplitude_to_db * move functions * clean up mel_filter_bank * fixup * credit librosa & torchaudio authors * add unit tests * tests for power_to_db and amplitude_to_db * add mel_filter_bank tests * rewrite STFT * add convenience spectrogram function * missing transpose * fewer transposes * add integration test to M-CTC-T * frame length can be either window or FFT length * rewrite stft API * add preemphasis coefficient * move argument * add log option to spectrogram * replace M-CTC-T feature extractor * fix api thing * replace whisper STFT * replace whisper mel filters * replace tvlt's stft * allow alternate window names * replace speecht5 stft * fixup * fix integration tests * fix doc comments * remove manual FFT length calculation * fix docs * go away, deprecation warnings * combine everything into spectrogram function * add deprecated functions back * fixup
This commit is contained in:
parent
431b04d8c4
commit
7f91950901
@ -12,10 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Utilities for `FeatureExtractors`
|
||||
|
||||
This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *Mel log spectrogram*.
|
||||
This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *log mel spectrogram*.
|
||||
|
||||
|
||||
Most of those are only useful if you are studying the code of the image processors in the library.
|
||||
Most of those are only useful if you are studying the code of the audio processors in the library.
|
||||
|
||||
## Audio Transformations
|
||||
|
||||
@ -23,12 +22,14 @@ Most of those are only useful if you are studying the code of the image processo
|
||||
|
||||
[[autodoc]] audio_utils.mel_to_hertz
|
||||
|
||||
[[autodoc]] audio_utils.get_mel_filter_banks
|
||||
[[autodoc]] audio_utils.mel_filter_bank
|
||||
|
||||
[[autodoc]] audio_utils.stft
|
||||
[[autodoc]] audio_utils.optimal_fft_length
|
||||
|
||||
[[autodoc]] audio_utils.window_function
|
||||
|
||||
[[autodoc]] audio_utils.spectrogram
|
||||
|
||||
[[autodoc]] audio_utils.power_to_db
|
||||
|
||||
[[autodoc]] audio_utils.fram_wave
|
||||
|
||||
|
||||
[[autodoc]] audio_utils.amplitude_to_db
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,66 +13,61 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Audio processing functions to extract feature from a raw audio. Should all be in numpy to support all frameworks, and
|
||||
remmove unecessary dependencies.
|
||||
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
|
||||
and remove unnecessary dependencies.
|
||||
"""
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
||||
|
||||
def hertz_to_mel(freq: float, mel_scale: str = "htk") -> float:
|
||||
"""Convert Hertz to Mels.
|
||||
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Convert frequency from hertz to mels.
|
||||
|
||||
Args:
|
||||
freqs (`float`):
|
||||
Frequencies in Hertz
|
||||
freq (`float` or `np.ndarray`):
|
||||
The frequency, or multiple frequencies, in hertz (Hz).
|
||||
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
||||
Scale to use, `htk` or `slaney`.
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
mels (`float`):
|
||||
Frequency in Mels
|
||||
`float` or `np.ndarray`: The frequencies on the mel scale.
|
||||
"""
|
||||
|
||||
if mel_scale not in ["slaney", "htk"]:
|
||||
raise ValueError('mel_scale should be one of "htk" or "slaney".')
|
||||
|
||||
if mel_scale == "htk":
|
||||
return 2595.0 * math.log10(1.0 + (freq / 700.0))
|
||||
return 2595.0 * np.log10(1.0 + (freq / 700.0))
|
||||
|
||||
# Fill in the linear part
|
||||
frequency_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
|
||||
mels = (freq - frequency_min) / f_sp
|
||||
|
||||
# Fill in the log-scale part
|
||||
min_log_hertz = 1000.0
|
||||
min_log_mel = (min_log_hertz - frequency_min) / f_sp
|
||||
logstep = math.log(6.4) / 27.0
|
||||
min_log_mel = 15.0
|
||||
logstep = 27.0 / np.log(6.4)
|
||||
mels = 3.0 * freq / 200.0
|
||||
|
||||
if freq >= min_log_hertz:
|
||||
mels = min_log_mel + math.log(freq / min_log_hertz) / logstep
|
||||
if isinstance(freq, np.ndarray):
|
||||
log_region = freq >= min_log_hertz
|
||||
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
|
||||
elif freq >= min_log_hertz:
|
||||
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
|
||||
|
||||
return mels
|
||||
|
||||
|
||||
def mel_to_hertz(mels: np.array, mel_scale: str = "htk") -> np.array:
|
||||
"""Convert mel bin numbers to frequencies.
|
||||
def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Convert frequency from mels to hertz.
|
||||
|
||||
Args:
|
||||
mels (`np.array`):
|
||||
Mel frequencies
|
||||
mels (`float` or `np.ndarray`):
|
||||
The frequency, or multiple frequencies, in mels.
|
||||
mel_scale (`str`, *optional*, `"htk"`):
|
||||
Scale to use: `htk` or `slaney`.
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
freqs (`np.array`):
|
||||
Mels converted to Hertz
|
||||
`float` or `np.ndarray`: The frequencies in hertz.
|
||||
"""
|
||||
|
||||
if mel_scale not in ["slaney", "htk"]:
|
||||
@ -81,50 +76,483 @@ def mel_to_hertz(mels: np.array, mel_scale: str = "htk") -> np.array:
|
||||
if mel_scale == "htk":
|
||||
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
|
||||
|
||||
# Fill in the linear scale
|
||||
frequency_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = frequency_min + f_sp * mels
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hertz = 1000.0
|
||||
min_log_mel = (min_log_hertz - frequency_min) / f_sp
|
||||
logstep = math.log(6.4) / 27.0
|
||||
min_log_mel = 15.0
|
||||
logstep = np.log(6.4) / 27.0
|
||||
freq = 200.0 * mels / 3.0
|
||||
|
||||
log_t = mels >= min_log_mel
|
||||
freqs[log_t] = min_log_hertz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
||||
if isinstance(mels, np.ndarray):
|
||||
log_region = mels >= min_log_mel
|
||||
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
|
||||
elif mels >= min_log_mel:
|
||||
freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
|
||||
|
||||
return freqs
|
||||
return freq
|
||||
|
||||
|
||||
def _create_triangular_filterbank(
|
||||
all_freqs: np.array,
|
||||
f_pts: np.array,
|
||||
) -> np.array:
|
||||
"""Create a triangular filter bank.
|
||||
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Creates a triangular filter bank.
|
||||
|
||||
Adapted from *torchaudio* and *librosa*.
|
||||
|
||||
Args:
|
||||
all_freqs (`np.array` of shape (`nb_frequency_bins`, )):
|
||||
Discrete frequencies used when the STFT was computed.
|
||||
f_pts (`np.array`, of shape (`nb_mel_filters`, )):
|
||||
Coordinates of the middle points of the triangular filters to create.
|
||||
fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
|
||||
Discrete frequencies of the FFT bins in Hz.
|
||||
filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
|
||||
Center frequencies of the triangular filters to create, in Hz.
|
||||
|
||||
Returns:
|
||||
fb (np.array):
|
||||
The filter bank of size (`nb_frequency_bins`, `nb_mel_filters`).
|
||||
`np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
|
||||
"""
|
||||
# Adapted from Librosa
|
||||
# calculate the difference between each filter mid point and each stft freq point in hertz
|
||||
f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
|
||||
slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (nb_frequency_bins, n_filter + 2)
|
||||
# create overlapping triangles
|
||||
zero = np.zeros(1)
|
||||
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (nb_frequency_bins, n_filter)
|
||||
up_slopes = slopes[:, 2:] / f_diff[1:] # (nb_frequency_bins, n_filter)
|
||||
fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))
|
||||
filter_diff = np.diff(filter_freqs)
|
||||
slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
|
||||
down_slopes = -slopes[:, :-2] / filter_diff[:-1]
|
||||
up_slopes = slopes[:, 2:] / filter_diff[1:]
|
||||
return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
|
||||
|
||||
return fb
|
||||
|
||||
def mel_filter_bank(
|
||||
num_frequency_bins: int,
|
||||
num_mel_filters: int,
|
||||
min_frequency: float,
|
||||
max_frequency: float,
|
||||
sampling_rate: int,
|
||||
norm: Optional[str] = None,
|
||||
mel_scale: str = "htk",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
|
||||
various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
|
||||
are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
|
||||
features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
|
||||
|
||||
Different banks of mel filters were introduced in the literature. The following variations are supported:
|
||||
|
||||
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
|
||||
bandwidth of `[0, 4600]` Hz.
|
||||
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
|
||||
bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
|
||||
- MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
|
||||
speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
|
||||
- HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
|
||||
12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
|
||||
|
||||
This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
|
||||
`melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
|
||||
|
||||
Args:
|
||||
num_frequency_bins (`int`):
|
||||
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
|
||||
num_mel_filters (`int`):
|
||||
Number of mel filters to generate.
|
||||
min_frequency (`float`):
|
||||
Lowest frequency of interest in Hz.
|
||||
max_frequency (`float`):
|
||||
Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
|
||||
sampling_rate (`int`):
|
||||
Sample rate of the audio waveform.
|
||||
norm (`str`, *optional*):
|
||||
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
|
||||
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
|
||||
projection matrix to go from a spectrogram to a mel spectrogram.
|
||||
"""
|
||||
if norm is not None and norm != "slaney":
|
||||
raise ValueError('norm must be one of None or "slaney"')
|
||||
|
||||
# frequencies of FFT bins in Hz
|
||||
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
|
||||
|
||||
# center points of the triangular mel filters
|
||||
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
||||
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
||||
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
|
||||
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
|
||||
|
||||
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
|
||||
|
||||
if norm is not None and norm == "slaney":
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
|
||||
mel_filters *= np.expand_dims(enorm, 0)
|
||||
|
||||
if (mel_filters.max(axis=0) == 0.0).any():
|
||||
warnings.warn(
|
||||
"At least one mel filter has all zero values. "
|
||||
f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
|
||||
f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
|
||||
)
|
||||
|
||||
return mel_filters
|
||||
|
||||
|
||||
def optimal_fft_length(window_length: int) -> int:
|
||||
"""
|
||||
Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
|
||||
already a power of two, rounds it up to the next power or two.
|
||||
|
||||
The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
|
||||
of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
|
||||
is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
|
||||
it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
|
||||
"""
|
||||
return 2 ** int(np.ceil(np.log2(window_length)))
|
||||
|
||||
|
||||
def window_function(
|
||||
window_length: int,
|
||||
name: str = "hann",
|
||||
periodic: bool = True,
|
||||
frame_length: Optional[int] = None,
|
||||
center: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Returns an array containing the specified window. This window is intended to be used with `stft`.
|
||||
|
||||
The following window types are supported:
|
||||
|
||||
- `"boxcar"`: a rectangular window
|
||||
- `"hamming"`: the Hamming window
|
||||
- `"hann"`: the Hann window
|
||||
|
||||
Args:
|
||||
window_length (`int`):
|
||||
The length of the window in samples.
|
||||
name (`str`, *optional*, defaults to `"hann"`):
|
||||
The name of the window function.
|
||||
periodic (`bool`, *optional*, defaults to `True`):
|
||||
Whether the window is periodic or symmetric.
|
||||
frame_length (`int`, *optional*):
|
||||
The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
|
||||
than the frame length, so that it will be zero-padded.
|
||||
center (`bool`, *optional*, defaults to `True`):
|
||||
Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
|
||||
"""
|
||||
length = window_length + 1 if periodic else window_length
|
||||
|
||||
if name == "boxcar":
|
||||
window = np.ones(length)
|
||||
elif name in ["hamming", "hamming_window"]:
|
||||
window = np.hamming(length)
|
||||
elif name in ["hann", "hann_window"]:
|
||||
window = np.hanning(length)
|
||||
else:
|
||||
raise ValueError(f"Unknown window function '{name}'")
|
||||
|
||||
if periodic:
|
||||
window = window[:-1]
|
||||
|
||||
if frame_length is None:
|
||||
return window
|
||||
|
||||
if window_length > frame_length:
|
||||
raise ValueError(
|
||||
f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
|
||||
)
|
||||
|
||||
padded_window = np.zeros(frame_length)
|
||||
offset = (frame_length - window_length) // 2 if center else 0
|
||||
padded_window[offset : offset + window_length] = window
|
||||
return padded_window
|
||||
|
||||
|
||||
# TODO This method does not support batching yet as we are mainly focused on inference.
|
||||
def spectrogram(
|
||||
waveform: np.ndarray,
|
||||
window: np.ndarray,
|
||||
frame_length: int,
|
||||
hop_length: int,
|
||||
fft_length: Optional[int] = None,
|
||||
power: Optional[float] = 1.0,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
onesided: bool = True,
|
||||
preemphasis: Optional[float] = None,
|
||||
mel_filters: Optional[np.ndarray] = None,
|
||||
mel_floor: float = 1e-10,
|
||||
log_mel: Optional[str] = None,
|
||||
reference: float = 1.0,
|
||||
min_value: float = 1e-10,
|
||||
db_range: Optional[float] = None,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
|
||||
|
||||
This function can create the following kinds of spectrograms:
|
||||
|
||||
- amplitude spectrogram (`power = 1.0`)
|
||||
- power spectrogram (`power = 2.0`)
|
||||
- complex-valued spectrogram (`power = None`)
|
||||
- log spectrogram (use `log_mel` argument)
|
||||
- mel spectrogram (provide `mel_filters`)
|
||||
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
|
||||
|
||||
How this works:
|
||||
|
||||
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
|
||||
- hop_length` samples.
|
||||
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
|
||||
3. The DFT is taken of each windowed frame.
|
||||
4. The results are stacked into a spectrogram.
|
||||
|
||||
We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
|
||||
|
||||
- The analysis frame. This is the size of the time slices that the input waveform is split into.
|
||||
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
|
||||
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
|
||||
|
||||
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
|
||||
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
|
||||
typically the next power of two.
|
||||
|
||||
Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
|
||||
`torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
|
||||
can be constructed.
|
||||
|
||||
Args:
|
||||
waveform (`np.ndarray` of shape `(length,)`):
|
||||
The input waveform. This must be a single real-valued, mono waveform.
|
||||
window (`np.ndarray` of shape `(frame_length,)`):
|
||||
The windowing function to apply, including zero-padding if necessary. The actual window length may be
|
||||
shorter than `frame_length`, but we're assuming the array has already been zero-padded.
|
||||
frame_length (`int`):
|
||||
The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
|
||||
allow smaller sizes.
|
||||
hop_length (`int`):
|
||||
The stride between successive analysis frames in samples.
|
||||
fft_length (`int`, *optional*):
|
||||
The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
|
||||
For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
|
||||
power (`float`, *optional*, defaults to 1.0):
|
||||
If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
|
||||
complex numbers.
|
||||
center (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
|
||||
`t` will start at time `t * hop_length`.
|
||||
pad_mode (`str`, *optional*, defaults to `"reflect"`):
|
||||
Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
|
||||
(pad with edge values), `"reflect"` (pads with mirrored values).
|
||||
onesided (`bool`, *optional*, defaults to `True`):
|
||||
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
|
||||
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
|
||||
preemphasis (`float`, *optional*)
|
||||
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
|
||||
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
|
||||
The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
|
||||
mel_floor (`float`, *optional*, defaults to 1e-10):
|
||||
Minimum value of mel frequency banks.
|
||||
log_mel (`str`, *optional*):
|
||||
How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
|
||||
the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
|
||||
used when `power` is not `None`.
|
||||
reference (`float`, *optional*, defaults to 1.0):
|
||||
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
||||
the loudest part to 0 dB. Must be greater than zero.
|
||||
min_value (`float`, *optional*, defaults to `1e-10`):
|
||||
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
||||
`log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
|
||||
amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
|
||||
db_range (`float`, *optional*):
|
||||
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
||||
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
||||
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
||||
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
|
||||
`np.complex64`.
|
||||
|
||||
Returns:
|
||||
`nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
|
||||
`(num_mel_filters, length)` for a mel spectrogram.
|
||||
"""
|
||||
window_length = len(window)
|
||||
|
||||
if fft_length is None:
|
||||
fft_length = frame_length
|
||||
|
||||
if frame_length > fft_length:
|
||||
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
|
||||
|
||||
if window_length != frame_length:
|
||||
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
|
||||
|
||||
if hop_length <= 0:
|
||||
raise ValueError("hop_length must be greater than zero")
|
||||
|
||||
if waveform.ndim != 1:
|
||||
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
|
||||
|
||||
if np.iscomplexobj(waveform):
|
||||
raise ValueError("Complex-valued input waveforms are not currently supported")
|
||||
|
||||
# center pad the waveform
|
||||
if center:
|
||||
padding = [(int(frame_length // 2), int(frame_length // 2))]
|
||||
waveform = np.pad(waveform, padding, mode=pad_mode)
|
||||
|
||||
# promote to float64, since np.fft uses float64 internally
|
||||
waveform = waveform.astype(np.float64)
|
||||
window = window.astype(np.float64)
|
||||
|
||||
# split waveform into frames of frame_length size
|
||||
num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
|
||||
|
||||
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
|
||||
spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
|
||||
|
||||
# rfft is faster than fft
|
||||
fft_func = np.fft.rfft if onesided else np.fft.fft
|
||||
buffer = np.zeros(fft_length)
|
||||
|
||||
timestep = 0
|
||||
for frame_idx in range(num_frames):
|
||||
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
||||
|
||||
if preemphasis is not None:
|
||||
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
|
||||
buffer[0] *= 1 - preemphasis
|
||||
|
||||
buffer[:frame_length] *= window
|
||||
|
||||
spectrogram[frame_idx] = fft_func(buffer)
|
||||
timestep += hop_length
|
||||
|
||||
# note: ** is much faster than np.power
|
||||
if power is not None:
|
||||
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
|
||||
|
||||
spectrogram = spectrogram.T
|
||||
|
||||
if mel_filters is not None:
|
||||
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
|
||||
|
||||
if power is not None and log_mel is not None:
|
||||
if log_mel == "log":
|
||||
spectrogram = np.log(spectrogram)
|
||||
elif log_mel == "log10":
|
||||
spectrogram = np.log10(spectrogram)
|
||||
elif log_mel == "dB":
|
||||
if power == 1.0:
|
||||
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
|
||||
elif power == 2.0:
|
||||
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
|
||||
else:
|
||||
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
|
||||
else:
|
||||
raise ValueError(f"Unknown log_mel option: {log_mel}")
|
||||
|
||||
spectrogram = np.asarray(spectrogram, dtype)
|
||||
|
||||
return spectrogram
|
||||
|
||||
|
||||
def power_to_db(
|
||||
spectrogram: np.ndarray,
|
||||
reference: float = 1.0,
|
||||
min_value: float = 1e-10,
|
||||
db_range: Optional[float] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
|
||||
logarithm properties for numerical stability.
|
||||
|
||||
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
||||
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
||||
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
||||
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
||||
|
||||
Based on the implementation of `librosa.power_to_db`.
|
||||
|
||||
Args:
|
||||
spectrogram (`np.ndarray`):
|
||||
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
|
||||
reference (`float`, *optional*, defaults to 1.0):
|
||||
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
||||
the loudest part to 0 dB. Must be greater than zero.
|
||||
min_value (`float`, *optional*, defaults to `1e-10`):
|
||||
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
||||
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
|
||||
db_range (`float`, *optional*):
|
||||
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
||||
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: the spectrogram in decibels
|
||||
"""
|
||||
if reference <= 0.0:
|
||||
raise ValueError("reference must be greater than zero")
|
||||
if min_value <= 0.0:
|
||||
raise ValueError("min_value must be greater than zero")
|
||||
|
||||
reference = max(min_value, reference)
|
||||
|
||||
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
||||
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
|
||||
|
||||
if db_range is not None:
|
||||
if db_range <= 0.0:
|
||||
raise ValueError("db_range must be greater than zero")
|
||||
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
||||
|
||||
return spectrogram
|
||||
|
||||
|
||||
def amplitude_to_db(
|
||||
spectrogram: np.ndarray,
|
||||
reference: float = 1.0,
|
||||
min_value: float = 1e-5,
|
||||
db_range: Optional[float] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
|
||||
basic logarithm properties for numerical stability.
|
||||
|
||||
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
||||
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
||||
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
||||
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
||||
|
||||
Args:
|
||||
spectrogram (`np.ndarray`):
|
||||
The input amplitude (mel) spectrogram.
|
||||
reference (`float`, *optional*, defaults to 1.0):
|
||||
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
||||
the loudest part to 0 dB. Must be greater than zero.
|
||||
min_value (`float`, *optional*, defaults to `1e-5`):
|
||||
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
||||
`log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
|
||||
db_range (`float`, *optional*):
|
||||
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
||||
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: the spectrogram in decibels
|
||||
"""
|
||||
if reference <= 0.0:
|
||||
raise ValueError("reference must be greater than zero")
|
||||
if min_value <= 0.0:
|
||||
raise ValueError("min_value must be greater than zero")
|
||||
|
||||
reference = max(min_value, reference)
|
||||
|
||||
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
||||
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
|
||||
|
||||
if db_range is not None:
|
||||
if db_range <= 0.0:
|
||||
raise ValueError("db_range must be greater than zero")
|
||||
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
||||
|
||||
return spectrogram
|
||||
|
||||
|
||||
### deprecated functions below this line ###
|
||||
|
||||
|
||||
def get_mel_filter_banks(
|
||||
@ -136,116 +564,21 @@ def get_mel_filter_banks(
|
||||
norm: Optional[str] = None,
|
||||
mel_scale: str = "htk",
|
||||
) -> np.array:
|
||||
"""
|
||||
Create a frequency bin conversion matrix used to obtain the Mel Spectrogram. This is called a *mel filter bank*,
|
||||
and various implementation exist, which differ in the number of filters, the shape of the filters, the way the
|
||||
filters are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
|
||||
features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
|
||||
This code is heavily inspired from the *torchaudio* implementation, see
|
||||
[here](https://pytorch.org/audio/stable/transforms.html) for more details.
|
||||
warnings.warn(
|
||||
"The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
|
||||
FutureWarning,
|
||||
)
|
||||
return mel_filter_bank(
|
||||
num_frequency_bins=nb_frequency_bins,
|
||||
num_mel_filters=nb_mel_filters,
|
||||
min_frequency=frequency_min,
|
||||
max_frequency=frequency_max,
|
||||
sampling_rate=sample_rate,
|
||||
norm=norm,
|
||||
mel_scale=mel_scale,
|
||||
)
|
||||
|
||||
|
||||
Tips:
|
||||
- Different banks of Mel filters were introduced in the litterature. The following variation are supported:
|
||||
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHertz
|
||||
and a speech bandwidth of `[0, 4600]` Hertz
|
||||
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a
|
||||
speech bandwidth `[0, 8000]` Hertz (sampling rate ≥ 16 kHertz).
|
||||
- MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate
|
||||
of 16 kHertz, and speech bandwidth [133, 6854] Hertz. This version also includes an area normalization.
|
||||
- HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes sampling
|
||||
rate of 12.5 kHertz and speech bandwidth [0, 6250] Hertz
|
||||
- The default parameters of `torchaudio`'s mel filterbanks implement the `"htk"` filers while `torchlibrosa`
|
||||
uses the `"slaney"` implementation.
|
||||
|
||||
Args:
|
||||
nb_frequency_bins (`int`):
|
||||
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
|
||||
nb_mel_filters (`int`):
|
||||
Number of Mel filers to generate.
|
||||
frequency_min (`float`):
|
||||
Minimum frequency of interest(Hertz).
|
||||
frequency_max (`float`):
|
||||
Maximum frequency of interest(Hertz).
|
||||
sample_rate (`int`):
|
||||
Sample rate of the audio waveform.
|
||||
norm (`str`, *optional*):
|
||||
If "slaney", divide the triangular Mel weights by the width of the mel band (area normalization).
|
||||
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
||||
Scale to use: `"htk"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Triangular filter banks (fb matrix) of shape (`nb_frequency_bins`, `nb_mel_filters`). This matrix
|
||||
is a projection matrix to go from a spectrogram to a Mel Spectrogram.
|
||||
|
||||
"""
|
||||
|
||||
if norm is not None and norm != "slaney":
|
||||
raise ValueError('norm must be one of None or "slaney"')
|
||||
|
||||
# freqency bins
|
||||
all_freqs = np.linspace(0, sample_rate // 2, nb_frequency_bins)
|
||||
|
||||
# Compute mim and max frequencies in mel scale
|
||||
m_min = hertz_to_mel(frequency_min, mel_scale=mel_scale)
|
||||
m_max = hertz_to_mel(frequency_max, mel_scale=mel_scale)
|
||||
|
||||
# create the centers of the triangular mel filters.
|
||||
m_pts = np.linspace(m_min, m_max, nb_mel_filters + 2)
|
||||
f_pts = mel_to_hertz(m_pts, mel_scale=mel_scale)
|
||||
|
||||
# create the filterbank
|
||||
filterbank = _create_triangular_filterbank(all_freqs, f_pts)
|
||||
|
||||
if norm is not None and norm == "slaney":
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (f_pts[2 : nb_mel_filters + 2] - f_pts[:nb_mel_filters])
|
||||
filterbank *= np.expand_dims(enorm, 0)
|
||||
|
||||
if (filterbank.max(axis=0) == 0.0).any():
|
||||
warnings.warn(
|
||||
"At least one mel filterbank has all zero values. "
|
||||
f"The value for `nb_mel_filters` ({nb_mel_filters}) may be set too high. "
|
||||
f"Or, the value for `nb_frequency_bins` ({nb_frequency_bins}) may be set too low."
|
||||
)
|
||||
|
||||
return filterbank
|
||||
|
||||
|
||||
def power_to_db(mel_spectrogram, top_db=None, a_min=1e-10, ref=1.0):
|
||||
"""
|
||||
Convert a mel spectrogram from power to db scale, this function is the numpy implementation of librosa.power_to_lb.
|
||||
It computes `10 * log10(mel_spectrogram / ref)`, using basic log properties for stability.
|
||||
|
||||
Tips:
|
||||
- The motivation behind applying the log function on the mel spectrogram is that humans do not hear loudness on
|
||||
a
|
||||
linear scale. Generally to double the percieved volume of a sound we need to put 8 times as much energy into
|
||||
it.
|
||||
- This means that large variations in energy may not sound all that different if the sound is loud to begin
|
||||
with. This compression operation makes the mel features match more closely what humans actually hear.
|
||||
|
||||
Args:
|
||||
mel_spectrogram (`np.array`):
|
||||
Input mel spectrogram.
|
||||
top_db (`int`, *optional*):
|
||||
The maximum decibel value.
|
||||
a_min (`int`, *optional*, default to 1e-10):
|
||||
Minimum value to use when cliping the mel spectrogram.
|
||||
ref (`float`, *optional*, default to 1.0):
|
||||
Maximum reference value used to scale the mel_spectrogram.
|
||||
|
||||
"""
|
||||
log_spec = 10 * np.log10(np.clip(mel_spectrogram, a_min=a_min, a_max=None))
|
||||
log_spec -= 10.0 * np.log10(np.maximum(a_min, ref))
|
||||
if top_db is not None:
|
||||
if top_db < 0:
|
||||
raise ValueError("top_db must be non-negative")
|
||||
log_spec = np.clip(log_spec, min=np.maximum(log_spec) - top_db, max=np.inf)
|
||||
return log_spec
|
||||
|
||||
|
||||
# TODO @ArthurZucker: This method does not support batching yet as we are mainly focus on inference.
|
||||
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
|
||||
"""
|
||||
In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
|
||||
@ -270,6 +603,10 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int =
|
||||
framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
|
||||
The framed waveforms that can be fed to `np.fft`.
|
||||
"""
|
||||
warnings.warn(
|
||||
"The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
|
||||
FutureWarning,
|
||||
)
|
||||
frames = []
|
||||
for i in range(0, waveform.shape[0] + 1, hop_length):
|
||||
if center:
|
||||
@ -298,9 +635,6 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int =
|
||||
return frames
|
||||
|
||||
|
||||
# TODO @ArthurZucker: This method does not support batching yet as we are mainly focus on inference.
|
||||
|
||||
|
||||
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
|
||||
"""
|
||||
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
|
||||
@ -337,6 +671,10 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int =
|
||||
spectrogram (`np.ndarray`):
|
||||
A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
|
||||
"""
|
||||
warnings.warn(
|
||||
"The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
|
||||
FutureWarning,
|
||||
)
|
||||
frame_size = frames.shape[1]
|
||||
|
||||
if fft_window_size is None:
|
||||
@ -355,5 +693,5 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int =
|
||||
np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
|
||||
else:
|
||||
fft_signal[:frame_size] = frame
|
||||
spectrogram[f] = fft(fft_signal, axis=0)[:nb_frequency_bins]
|
||||
spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
|
||||
return spectrogram.T
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...audio_utils import fram_wave, get_mel_filter_banks, power_to_db, stft
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import TensorType, logging
|
||||
@ -116,21 +116,21 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.sampling_rate = sampling_rate
|
||||
self.frequency_min = frequency_min
|
||||
self.frequency_max = frequency_max
|
||||
self.mel_filters = get_mel_filter_banks(
|
||||
nb_frequency_bins=self.nb_frequency_bins,
|
||||
nb_mel_filters=feature_size,
|
||||
frequency_min=frequency_min,
|
||||
frequency_max=frequency_max,
|
||||
sample_rate=sampling_rate,
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=self.nb_frequency_bins,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=frequency_min,
|
||||
max_frequency=frequency_max,
|
||||
sampling_rate=sampling_rate,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
self.mel_filters_slaney = get_mel_filter_banks(
|
||||
nb_frequency_bins=self.nb_frequency_bins,
|
||||
nb_mel_filters=feature_size,
|
||||
frequency_min=frequency_min,
|
||||
frequency_max=frequency_max,
|
||||
sample_rate=sampling_rate,
|
||||
self.mel_filters_slaney = mel_filter_bank(
|
||||
num_frequency_bins=self.nb_frequency_bins,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=frequency_min,
|
||||
max_frequency=frequency_max,
|
||||
sampling_rate=sampling_rate,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
)
|
||||
@ -153,24 +153,25 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
|
||||
"""
|
||||
Compute the log-Mel spectrogram of the provided `waveform` using the `hanning` window. In CLAP, two different
|
||||
filter banks are used depending on the truncation pattern:
|
||||
- `self.mel_filters`: they correspond to the defaults parameters of `torchaduio` which can be obtained from
|
||||
Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
|
||||
banks are used depending on the truncation pattern:
|
||||
- `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
|
||||
calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
|
||||
is set to `"fusion"`.
|
||||
- `self.mel_filteres_slaney` : they correspond to the defaults parameters of `torchlibrosa` which used
|
||||
- `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used
|
||||
`librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
|
||||
implementation when the truncation mode is not `"fusion"`.
|
||||
"""
|
||||
window = np.hanning(self.fft_window_size + 1)[:-1]
|
||||
frames = fram_wave(waveform, self.hop_length, self.fft_window_size)
|
||||
spectrogram = stft(frames, window, fft_window_size=self.fft_window_size)
|
||||
|
||||
magnitudes = np.abs(spectrogram) ** 2
|
||||
mel_spectrogram = np.matmul(mel_filters.T, magnitudes)
|
||||
log_mel_spectrogram = power_to_db(mel_spectrogram).T
|
||||
log_mel_spectrogram = np.asarray(log_mel_spectrogram, np.float32)
|
||||
return log_mel_spectrogram
|
||||
log_mel_spectrogram = spectrogram(
|
||||
waveform,
|
||||
window_function(self.fft_window_size, "hann"),
|
||||
frame_length=self.fft_window_size,
|
||||
hop_length=self.hop_length,
|
||||
power=2.0,
|
||||
mel_filters=mel_filters,
|
||||
log_mel="dB",
|
||||
)
|
||||
return log_mel_spectrogram.T
|
||||
|
||||
def _random_mel_fusion(self, mel, total_frames, chunk_frames):
|
||||
ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
|
||||
|
@ -20,9 +20,8 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from packaging import version
|
||||
|
||||
from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...file_utils import PaddingStrategy, TensorType
|
||||
@ -31,13 +30,6 @@ from ...utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
parsed_torchaudio_version_base = version.parse(version.parse(torchaudio.__version__).base_version)
|
||||
if not parsed_torchaudio_version_base >= version.parse("0.10"):
|
||||
logger.warning(
|
||||
f"You are using torchaudio=={torchaudio.__version__}, but torchaudio>=0.10.0 is required to use "
|
||||
"MCTCTFeatureExtractor. This requires torch>=1.10.0. Please upgrade torch and torchaudio."
|
||||
)
|
||||
|
||||
|
||||
class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
||||
r"""
|
||||
@ -110,68 +102,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.sample_size = win_length * sampling_rate // 1000
|
||||
self.sample_stride = hop_length * sampling_rate // 1000
|
||||
|
||||
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
|
||||
self.n_fft = optimal_fft_length(self.sample_size)
|
||||
self.n_freqs = (self.n_fft // 2) + 1
|
||||
|
||||
@staticmethod
|
||||
def _num_frames_calc(in_size, frame_size, frame_stride):
|
||||
return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
|
||||
|
||||
@staticmethod
|
||||
def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
|
||||
scale = frame_signal_scale
|
||||
frames = np.zeros(n_frames * window_length)
|
||||
for frame_idx in range(n_frames):
|
||||
start = frame_idx * window_length
|
||||
end = (frame_idx + 1) * window_length
|
||||
wave_start = frame_idx * sample_stride
|
||||
wave_end = frame_idx * sample_stride + window_length
|
||||
frames[start:end] = scale * one_waveform[wave_start:wave_end]
|
||||
|
||||
return frames
|
||||
|
||||
@staticmethod
|
||||
def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
|
||||
if frames.size % window_length != 0:
|
||||
raise ValueError(
|
||||
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
|
||||
f" window_length={window_length}."
|
||||
)
|
||||
|
||||
n_frames = frames.size // window_length
|
||||
for frame_idx in range(n_frames, 0, -1):
|
||||
start = (frame_idx - 1) * window_length
|
||||
end = frame_idx * window_length - 1
|
||||
frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
|
||||
frames[start] *= 1 - preemphasis_coeff
|
||||
|
||||
@staticmethod
|
||||
def _windowing(frames, window_length, window):
|
||||
if frames.size % window_length != 0:
|
||||
raise ValueError(
|
||||
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
|
||||
f" window_length={window_length}."
|
||||
)
|
||||
|
||||
shaped = frames.reshape(-1, window_length)
|
||||
shaped = window * shaped
|
||||
return shaped
|
||||
|
||||
@staticmethod
|
||||
def _dft(frames, K, n_frames, n_samples, n_fft):
|
||||
dft = np.zeros([n_frames, K])
|
||||
|
||||
for frame in range(n_frames):
|
||||
begin = frame * n_samples
|
||||
|
||||
inwards_buffer = frames[begin : begin + n_samples]
|
||||
inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
|
||||
out = np.fft.rfft(inwards_buffer)
|
||||
|
||||
dft[frame] = np.abs(out[:K])
|
||||
|
||||
return dft
|
||||
|
||||
def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
|
||||
"""
|
||||
Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
|
||||
@ -183,36 +116,27 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
window = window.numpy()
|
||||
|
||||
fbanks = torchaudio.functional.melscale_fbanks(
|
||||
n_freqs=self.n_freqs,
|
||||
f_min=0.0, # change this to zeros
|
||||
f_max=self.sampling_rate / 2.0,
|
||||
n_mels=self.feature_size,
|
||||
sample_rate=self.sampling_rate,
|
||||
fbanks = mel_filter_bank(
|
||||
num_frequency_bins=self.n_freqs,
|
||||
num_mel_filters=self.feature_size,
|
||||
min_frequency=0.0,
|
||||
max_frequency=self.sampling_rate / 2.0,
|
||||
sampling_rate=self.sampling_rate,
|
||||
)
|
||||
|
||||
fbanks = fbanks.numpy()
|
||||
|
||||
n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
|
||||
|
||||
frames = self._frame_signal(
|
||||
one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
|
||||
msfc_features = spectrogram(
|
||||
one_waveform * self.frame_signal_scale,
|
||||
window=window,
|
||||
frame_length=self.sample_size,
|
||||
hop_length=self.sample_stride,
|
||||
fft_length=self.n_fft,
|
||||
center=False,
|
||||
preemphasis=self.preemphasis_coeff,
|
||||
mel_filters=fbanks,
|
||||
mel_floor=self.mel_floor,
|
||||
log_mel="log",
|
||||
)
|
||||
|
||||
self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
|
||||
|
||||
frames = self._windowing(frames, self.sample_size, window)
|
||||
|
||||
dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
|
||||
|
||||
# msfc_features = STFT * mel frequency banks.
|
||||
msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
|
||||
|
||||
# clamp feature values then log scale, as implemented in flashlight
|
||||
msfc_features = np.maximum(msfc_features, self.mel_floor)
|
||||
msfc_features = np.log(msfc_features)
|
||||
|
||||
return msfc_features
|
||||
return msfc_features.T
|
||||
|
||||
def _normalize_one(self, x, input_length, padding_value):
|
||||
# make sure we normalize float32 arrays
|
||||
|
@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...audio_utils import get_mel_filter_banks
|
||||
from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
@ -110,18 +110,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
self.sample_size = win_length * sampling_rate // 1000
|
||||
self.sample_stride = hop_length * sampling_rate // 1000
|
||||
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
|
||||
self.n_fft = optimal_fft_length(self.sample_size)
|
||||
self.n_freqs = (self.n_fft // 2) + 1
|
||||
|
||||
window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
|
||||
self.window = window.numpy().astype(np.float64)
|
||||
|
||||
self.mel_filters = get_mel_filter_banks(
|
||||
nb_frequency_bins=self.n_freqs,
|
||||
nb_mel_filters=self.num_mel_bins,
|
||||
frequency_min=self.fmin,
|
||||
frequency_max=self.fmax,
|
||||
sample_rate=self.sampling_rate,
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=self.n_freqs,
|
||||
num_mel_filters=self.num_mel_bins,
|
||||
min_frequency=self.fmin,
|
||||
max_frequency=self.fmax,
|
||||
sampling_rate=self.sampling_rate,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
)
|
||||
@ -160,31 +160,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
return normed_input_values
|
||||
|
||||
@staticmethod
|
||||
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Calculates the magnitude spectrogram over one waveform array.
|
||||
"""
|
||||
# center pad the waveform
|
||||
padding = [(int(fft_length // 2), int(fft_length // 2))]
|
||||
waveform = np.pad(waveform, padding, mode="reflect")
|
||||
waveform_size = waveform.size
|
||||
|
||||
# promote to float64, since np.fft uses float64 internally
|
||||
waveform = waveform.astype(np.float64)
|
||||
|
||||
num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
|
||||
num_frequency_bins = (fft_length // 2) + 1
|
||||
spectrogram = np.empty((num_frames, num_frequency_bins))
|
||||
|
||||
start = 0
|
||||
for frame_idx in range(num_frames):
|
||||
frame = waveform[start : start + fft_length] * window
|
||||
spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
|
||||
start += hop_length
|
||||
|
||||
return spectrogram
|
||||
|
||||
def _extract_mel_features(
|
||||
self,
|
||||
one_waveform: np.ndarray,
|
||||
@ -192,14 +167,17 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
|
||||
"""
|
||||
Extracts log-mel filterbank features for one waveform array (unbatched).
|
||||
"""
|
||||
if self.n_fft != self.sample_size:
|
||||
raise NotImplementedError(
|
||||
f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
|
||||
)
|
||||
|
||||
stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
|
||||
|
||||
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
|
||||
log_mel_spec = spectrogram(
|
||||
one_waveform,
|
||||
window=self.window,
|
||||
frame_length=self.sample_size,
|
||||
hop_length=self.sample_stride,
|
||||
fft_length=self.n_fft,
|
||||
mel_filters=self.mel_filters,
|
||||
mel_floor=self.mel_floor,
|
||||
log_mel="log10",
|
||||
)
|
||||
return log_mel_spec.T
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -18,8 +18,8 @@ from math import ceil
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
@ -83,143 +83,34 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.hop_length = sampling_rate // hop_length_to_sampling_rate
|
||||
self.sampling_rate = sampling_rate
|
||||
self.padding_value = padding_value
|
||||
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
|
||||
|
||||
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.get_mel_filters with 45.245640471924965->59.99247463746737
|
||||
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
|
||||
# Initialize the weights
|
||||
n_mels = int(n_mels)
|
||||
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||
|
||||
# Center freqs of each FFT bin
|
||||
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
min_mel = 0.0
|
||||
max_mel = 59.99247463746737
|
||||
|
||||
mels = np.linspace(min_mel, max_mel, n_mels + 2)
|
||||
|
||||
mels = np.asanyarray(mels)
|
||||
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
# If we have vector data, vectorize
|
||||
log_t = mels >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
||||
|
||||
mel_f = freqs
|
||||
|
||||
fdiff = np.diff(mel_f)
|
||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
# lower and upper slopes for all bins
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i + 2] / fdiff[i + 1]
|
||||
|
||||
# .. then intersect them with each other and zero
|
||||
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
||||
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, np.newaxis]
|
||||
|
||||
return weights
|
||||
|
||||
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.fram_wave
|
||||
def fram_wave(self, waveform, center=True):
|
||||
"""
|
||||
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
|
||||
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
|
||||
new frame.
|
||||
|
||||
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
|
||||
"""
|
||||
frames = []
|
||||
for i in range(0, waveform.shape[0] + 1, self.hop_length):
|
||||
half_window = (self.n_fft - 1) // 2 + 1
|
||||
if center:
|
||||
start = i - half_window if i > half_window else 0
|
||||
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
|
||||
|
||||
frame = waveform[start:end]
|
||||
|
||||
if start == 0:
|
||||
padd_width = (-i + half_window, 0)
|
||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
||||
|
||||
elif end == waveform.shape[0]:
|
||||
padd_width = (0, (i - waveform.shape[0] + half_window))
|
||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
||||
|
||||
else:
|
||||
frame = waveform[i : i + self.n_fft]
|
||||
frame_width = frame.shape[0]
|
||||
if frame_width < waveform.shape[0]:
|
||||
frame = np.lib.pad(
|
||||
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
frames.append(frame)
|
||||
return np.stack(frames, 0)
|
||||
|
||||
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.stft
|
||||
def stft(self, frames, window):
|
||||
"""
|
||||
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
|
||||
results as `torch.stft`.
|
||||
"""
|
||||
frame_size = frames.shape[1]
|
||||
fft_size = self.n_fft
|
||||
|
||||
if fft_size is None:
|
||||
fft_size = frame_size
|
||||
|
||||
if fft_size < frame_size:
|
||||
raise ValueError("FFT size must greater or equal the frame size")
|
||||
# number of FFT bins to store
|
||||
num_fft_bins = (fft_size >> 1) + 1
|
||||
|
||||
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
|
||||
fft_signal = np.zeros(fft_size)
|
||||
|
||||
for f, frame in enumerate(frames):
|
||||
if window is not None:
|
||||
np.multiply(frame, window, out=fft_signal[:frame_size])
|
||||
else:
|
||||
fft_signal[:frame_size] = frame
|
||||
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
|
||||
return data.T
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=1 + n_fft // 2,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=0.0,
|
||||
max_frequency=22050.0,
|
||||
sampling_rate=sampling_rate,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
).T
|
||||
|
||||
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
||||
"""
|
||||
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
|
||||
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
|
||||
implementation with 1e-5 tolerance.
|
||||
"""
|
||||
window = np.hanning(self.n_fft + 1)[:-1]
|
||||
|
||||
frames = self.fram_wave(waveform)
|
||||
stft = self.stft(frames, window=window)
|
||||
magnitudes = np.abs(stft[:, :-1]) ** 2
|
||||
|
||||
filters = self.mel_filters
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = 10.0 * np.log10(np.maximum(1e-10, mel_spec))
|
||||
log_spec -= 10.0 * np.log10(np.maximum(1e-10, 1.0))
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - 80.0)
|
||||
log_spec = spectrogram(
|
||||
waveform,
|
||||
window_function(self.n_fft, "hann"),
|
||||
frame_length=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
power=2.0,
|
||||
mel_filters=self.mel_filters.T,
|
||||
log_mel="dB",
|
||||
db_range=80.0,
|
||||
)
|
||||
log_spec = log_spec[:, :-1]
|
||||
log_spec = log_spec - 20.0
|
||||
log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0
|
||||
|
||||
return log_spec
|
||||
|
||||
def __call__(
|
||||
|
@ -19,8 +19,8 @@ import copy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import TensorType, logging
|
||||
@ -81,138 +81,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
self.n_samples = chunk_length * sampling_rate
|
||||
self.nb_max_frames = self.n_samples // hop_length
|
||||
self.sampling_rate = sampling_rate
|
||||
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
|
||||
|
||||
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
|
||||
# Initialize the weights
|
||||
n_mels = int(n_mels)
|
||||
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||
|
||||
# Center freqs of each FFT bin
|
||||
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
min_mel = 0.0
|
||||
max_mel = 45.245640471924965
|
||||
|
||||
mels = np.linspace(min_mel, max_mel, n_mels + 2)
|
||||
|
||||
mels = np.asanyarray(mels)
|
||||
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
# If we have vector data, vectorize
|
||||
log_t = mels >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
||||
|
||||
mel_f = freqs
|
||||
|
||||
fdiff = np.diff(mel_f)
|
||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
# lower and upper slopes for all bins
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i + 2] / fdiff[i + 1]
|
||||
|
||||
# .. then intersect them with each other and zero
|
||||
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
||||
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, np.newaxis]
|
||||
|
||||
return weights
|
||||
|
||||
def fram_wave(self, waveform, center=True):
|
||||
"""
|
||||
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
|
||||
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
|
||||
new frame.
|
||||
|
||||
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
|
||||
"""
|
||||
frames = []
|
||||
for i in range(0, waveform.shape[0] + 1, self.hop_length):
|
||||
half_window = (self.n_fft - 1) // 2 + 1
|
||||
if center:
|
||||
start = i - half_window if i > half_window else 0
|
||||
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
|
||||
|
||||
frame = waveform[start:end]
|
||||
|
||||
if start == 0:
|
||||
padd_width = (-i + half_window, 0)
|
||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
||||
|
||||
elif end == waveform.shape[0]:
|
||||
padd_width = (0, (i - waveform.shape[0] + half_window))
|
||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
||||
|
||||
else:
|
||||
frame = waveform[i : i + self.n_fft]
|
||||
frame_width = frame.shape[0]
|
||||
if frame_width < waveform.shape[0]:
|
||||
frame = np.lib.pad(
|
||||
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
frames.append(frame)
|
||||
return np.stack(frames, 0)
|
||||
|
||||
def stft(self, frames, window):
|
||||
"""
|
||||
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
|
||||
results as `torch.stft`.
|
||||
"""
|
||||
frame_size = frames.shape[1]
|
||||
fft_size = self.n_fft
|
||||
|
||||
if fft_size is None:
|
||||
fft_size = frame_size
|
||||
|
||||
if fft_size < frame_size:
|
||||
raise ValueError("FFT size must greater or equal the frame size")
|
||||
# number of FFT bins to store
|
||||
num_fft_bins = (fft_size >> 1) + 1
|
||||
|
||||
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
|
||||
fft_signal = np.zeros(fft_size)
|
||||
|
||||
for f, frame in enumerate(frames):
|
||||
if window is not None:
|
||||
np.multiply(frame, window, out=fft_signal[:frame_size])
|
||||
else:
|
||||
fft_signal[:frame_size] = frame
|
||||
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
|
||||
return data.T
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=1 + n_fft // 2,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=0.0,
|
||||
max_frequency=8000.0,
|
||||
sampling_rate=sampling_rate,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
)
|
||||
|
||||
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
||||
"""
|
||||
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
|
||||
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
|
||||
implementation with 1e-5 tolerance.
|
||||
"""
|
||||
window = np.hanning(self.n_fft + 1)[:-1]
|
||||
|
||||
frames = self.fram_wave(waveform)
|
||||
stft = self.stft(frames, window=window)
|
||||
magnitudes = np.abs(stft[:, :-1]) ** 2
|
||||
|
||||
filters = self.mel_filters
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
|
||||
log_spec = spectrogram(
|
||||
waveform,
|
||||
window_function(self.n_fft, "hann"),
|
||||
frame_length=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
power=2.0,
|
||||
mel_filters=self.mel_filters,
|
||||
log_mel="log10",
|
||||
)
|
||||
log_spec = log_spec[:, :-1]
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
|
||||
return log_spec
|
||||
|
||||
@staticmethod
|
||||
|
@ -160,6 +160,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feaure_extractor = ASTFeatureExtractor()
|
||||
input_values = feaure_extractor(input_speech, return_tensors="pt").input_values
|
||||
feature_extractor = ASTFeatureExtractor()
|
||||
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
|
||||
self.assertEquals(input_values.shape, (1, 1024, 128))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
|
@ -21,7 +21,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
@ -47,7 +47,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class MCTCTFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
@ -102,7 +101,6 @@ class MCTCTFeatureExtractionTester(unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
|
||||
|
||||
@ -271,3 +269,38 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
|
||||
self.assertTrue(np_processed.input_features.dtype == np.float32)
|
||||
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_integration(self):
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
[
|
||||
1.1280, 1.1319, 1.2744, 1.4369, 1.4328, 1.3671, 1.2889, 1.3046,
|
||||
1.4419, 0.8387, 0.2995, 0.0404, 0.1068, 0.0472, 0.3728, 1.3356,
|
||||
1.4491, 0.4770, 0.3997, 0.2776, 0.3184, -0.1243, -0.1170, -0.0828
|
||||
],
|
||||
[
|
||||
1.0826, 1.0565, 1.2110, 1.3886, 1.3416, 1.2009, 1.1894, 1.2707,
|
||||
1.5153, 0.7005, 0.4916, 0.4017, 0.3743, 0.1935, 0.4228, 1.1084,
|
||||
0.9768, 0.0608, 0.2044, 0.1723, 0.0433, -0.2360, -0.2478, -0.2643
|
||||
],
|
||||
[
|
||||
1.0590, 0.9923, 1.1185, 1.3309, 1.1971, 1.0067, 1.0080, 1.2036,
|
||||
1.5397, 1.0383, 0.7672, 0.7551, 0.4878, 0.8771, 0.7565, 0.8775,
|
||||
0.9042, 0.4595, 0.6157, 0.4954, 0.1857, 0.0307, 0.0199, 0.1033
|
||||
],
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
input_features = feature_extractor(input_speech, sampling_rate=16000, return_tensors="pt").input_features
|
||||
self.assertTrue(np.allclose(input_features[0, 100:103], expected, atol=1e-4))
|
||||
|
@ -247,3 +247,27 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
self.assertTrue(np_processed.input_features.dtype == np.float32)
|
||||
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_integration(self):
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
-1.5745, -1.7713, -1.7020, -1.6069, -1.2250, -1.1105, -0.9072, -0.8241,
|
||||
-1.2310, -0.8098, -0.3320, -0.4101, -0.7985, -0.4996, -0.8213, -0.9128,
|
||||
-1.0420, -1.1286, -1.0440, -0.7999, -0.8405, -1.2275, -1.5443, -1.4625,
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEquals(input_features.shape, (1, 584, 24))
|
||||
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
|
||||
|
@ -395,7 +395,8 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = SpeechT5FeatureExtractor()
|
||||
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
|
||||
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
self.assertEquals(input_values.shape, (1, 93680))
|
||||
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))
|
||||
|
||||
def test_integration_target(self):
|
||||
# fmt: off
|
||||
@ -410,4 +411,5 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = SpeechT5FeatureExtractor()
|
||||
input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values
|
||||
self.assertEquals(input_values.shape, (1, 366, 80))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
|
@ -198,10 +198,10 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
|
||||
def test_integration(self):
|
||||
input_speech = self._load_datasamples(1)
|
||||
feaure_extractor = TvltFeatureExtractor()
|
||||
audio_values = feaure_extractor(input_speech, return_tensors="pt").audio_values
|
||||
feature_extractor = TvltFeatureExtractor()
|
||||
audio_values = feature_extractor(input_speech, return_tensors="pt").audio_values
|
||||
|
||||
self.assertTrue(audio_values.shape, [1, 1, 192, 128])
|
||||
self.assertEquals(audio_values.shape, (1, 1, 192, 128))
|
||||
|
||||
expected_slice = torch.tensor([[-0.3032, -0.2708], [-0.4434, -0.4007]])
|
||||
self.assertTrue(torch.allclose(audio_values[0, 0, :2, :2], expected_slice, atol=1e-4))
|
||||
|
@ -218,8 +218,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
feaure_extractor = WhisperFeatureExtractor()
|
||||
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
|
652
tests/utils/test_audio_utils.py
Normal file
652
tests/utils/test_audio_utils.py
Normal file
@ -0,0 +1,652 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers.audio_utils import (
|
||||
amplitude_to_db,
|
||||
hertz_to_mel,
|
||||
mel_filter_bank,
|
||||
mel_to_hertz,
|
||||
power_to_db,
|
||||
spectrogram,
|
||||
window_function,
|
||||
)
|
||||
|
||||
|
||||
class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
def test_hertz_to_mel(self):
|
||||
self.assertEqual(hertz_to_mel(0.0), 0.0)
|
||||
self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)
|
||||
|
||||
inputs = np.array([100, 200])
|
||||
expected = np.array([150.48910241, 283.22989816])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))
|
||||
|
||||
self.assertEqual(hertz_to_mel(0.0, "slaney"), 0.0)
|
||||
self.assertEqual(hertz_to_mel(100, "slaney"), 1.5)
|
||||
|
||||
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
hertz_to_mel(100, mel_scale=None)
|
||||
|
||||
def test_mel_to_hertz(self):
|
||||
self.assertEqual(mel_to_hertz(0.0), 0.0)
|
||||
self.assertAlmostEqual(mel_to_hertz(150.48910241), 100)
|
||||
|
||||
inputs = np.array([150.48910241, 283.22989816])
|
||||
expected = np.array([100, 200])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs), expected))
|
||||
|
||||
self.assertEqual(mel_to_hertz(0.0, "slaney"), 0.0)
|
||||
self.assertEqual(mel_to_hertz(1.5, "slaney"), 100)
|
||||
|
||||
inputs = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
|
||||
expected = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mel_to_hertz(100, mel_scale=None)
|
||||
|
||||
def test_mel_filter_bank_shape(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=513,
|
||||
num_mel_filters=13,
|
||||
min_frequency=100,
|
||||
max_frequency=4000,
|
||||
sampling_rate=16000,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=513,
|
||||
num_mel_filters=13,
|
||||
min_frequency=100,
|
||||
max_frequency=4000,
|
||||
sampling_rate=16000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
def test_mel_filter_bank_htk(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
num_mel_filters=4,
|
||||
min_frequency=0,
|
||||
max_frequency=2000,
|
||||
sampling_rate=4000,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
[0.0 , 0.0 , 0.0 , 0.0 ],
|
||||
[0.61454786, 0.0 , 0.0 , 0.0 ],
|
||||
[0.82511046, 0.17488954, 0.0 , 0.0 ],
|
||||
[0.35597035, 0.64402965, 0.0 , 0.0 ],
|
||||
[0.0 , 0.91360726, 0.08639274, 0.0 ],
|
||||
[0.0 , 0.55547007, 0.44452993, 0.0 ],
|
||||
[0.0 , 0.19733289, 0.80266711, 0.0 ],
|
||||
[0.0 , 0.0 , 0.87724349, 0.12275651],
|
||||
[0.0 , 0.0 , 0.6038449 , 0.3961551 ],
|
||||
[0.0 , 0.0 , 0.33044631, 0.66955369],
|
||||
[0.0 , 0.0 , 0.05704771, 0.94295229],
|
||||
[0.0 , 0.0 , 0.0 , 0.83483975],
|
||||
[0.0 , 0.0 , 0.0 , 0.62612982],
|
||||
[0.0 , 0.0 , 0.0 , 0.41741988],
|
||||
[0.0 , 0.0 , 0.0 , 0.20870994],
|
||||
[0.0 , 0.0 , 0.0 , 0.0 ]
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected))
|
||||
|
||||
def test_mel_filter_bank_slaney(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
num_mel_filters=4,
|
||||
min_frequency=0,
|
||||
max_frequency=2000,
|
||||
sampling_rate=4000,
|
||||
norm=None,
|
||||
mel_scale="slaney",
|
||||
)
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
[0.0 , 0.0 , 0.0 , 0.0 ],
|
||||
[0.39869419, 0.0 , 0.0 , 0.0 ],
|
||||
[0.79738839, 0.0 , 0.0 , 0.0 ],
|
||||
[0.80391742, 0.19608258, 0.0 , 0.0 ],
|
||||
[0.40522322, 0.59477678, 0.0 , 0.0 ],
|
||||
[0.00652903, 0.99347097, 0.0 , 0.0 ],
|
||||
[0.0 , 0.60796161, 0.39203839, 0.0 ],
|
||||
[0.0 , 0.20939631, 0.79060369, 0.0 ],
|
||||
[0.0 , 0.0 , 0.84685344, 0.15314656],
|
||||
[0.0 , 0.0 , 0.52418477, 0.47581523],
|
||||
[0.0 , 0.0 , 0.2015161 , 0.7984839 ],
|
||||
[0.0 , 0.0 , 0.0 , 0.9141874 ],
|
||||
[0.0 , 0.0 , 0.0 , 0.68564055],
|
||||
[0.0 , 0.0 , 0.0 , 0.4570937 ],
|
||||
[0.0 , 0.0 , 0.0 , 0.22854685],
|
||||
[0.0 , 0.0 , 0.0 , 0.0 ]
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected))
|
||||
|
||||
def test_mel_filter_bank_slaney_norm(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
num_mel_filters=4,
|
||||
min_frequency=0,
|
||||
max_frequency=2000,
|
||||
sampling_rate=4000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
)
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
[1.19217795e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
[2.38435591e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
[2.40387905e-03, 5.86232616e-04, 0.00000000e+00, 0.00000000e+00],
|
||||
[1.21170110e-03, 1.77821783e-03, 0.00000000e+00, 0.00000000e+00],
|
||||
[1.95231437e-05, 2.97020305e-03, 0.00000000e+00, 0.00000000e+00],
|
||||
[0.00000000e+00, 1.81763684e-03, 1.04857612e-03, 0.00000000e+00],
|
||||
[0.00000000e+00, 6.26036972e-04, 2.11460963e-03, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 2.26505954e-03, 3.07332945e-04],
|
||||
[0.00000000e+00, 0.00000000e+00, 1.40202503e-03, 9.54861093e-04],
|
||||
[0.00000000e+00, 0.00000000e+00, 5.38990521e-04, 1.60238924e-03],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.83458185e-03],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.37593638e-03],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.17290923e-04],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 4.58645462e-04],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected))
|
||||
|
||||
def test_window_function(self):
|
||||
window = window_function(16, "hann")
|
||||
self.assertEqual(len(window), 16)
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.0, 0.03806023, 0.14644661, 0.30865828, 0.5, 0.69134172, 0.85355339, 0.96193977,
|
||||
1.0, 0.96193977, 0.85355339, 0.69134172, 0.5, 0.30865828, 0.14644661, 0.03806023,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(window, expected))
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_spectrogram_impulse(self):
|
||||
waveform = np.zeros(40)
|
||||
waveform[9] = 1.0 # impulse shifted in time
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(12, "hann", frame_length=16),
|
||||
frame_length=16,
|
||||
hop_length=4,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (9, 11))
|
||||
|
||||
expected = np.array([[0.0, 0.0669873, 0.9330127, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
|
||||
self.assertTrue(np.allclose(spec, expected))
|
||||
|
||||
def test_spectrogram_integration_test(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann", frame_length=512),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,
|
||||
0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,
|
||||
0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,
|
||||
0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,
|
||||
0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,
|
||||
0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,
|
||||
0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,
|
||||
0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,
|
||||
0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,
|
||||
0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,
|
||||
0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,
|
||||
0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,
|
||||
0.0293578 , 0.03452379, 0.02194803, 0.01676056,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann"),
|
||||
frame_length=400,
|
||||
hop_length=128,
|
||||
fft_length=512,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
||||
|
||||
def test_spectrogram_center_padding(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(512, "hann"),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,
|
||||
0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,
|
||||
0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,
|
||||
0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,
|
||||
0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,
|
||||
0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,
|
||||
0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,
|
||||
0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,
|
||||
0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,
|
||||
0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,
|
||||
0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,
|
||||
0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,
|
||||
0.00217659, 0.00276204, 0.00260835, 0.00299299,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 0], expected))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(512, "hann"),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
center=True,
|
||||
pad_mode="constant",
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,
|
||||
0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,
|
||||
0.00062093, 0.0031821 , 0.00419456, 0.00689327, 0.01106367,
|
||||
0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,
|
||||
0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,
|
||||
0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,
|
||||
0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,
|
||||
0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,
|
||||
0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,
|
||||
0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,
|
||||
0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,
|
||||
0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,
|
||||
0.00788239, 0.00664407, 0.00824227, 0.00628301,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 0], expected))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(512, "hann"),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
center=False,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 728))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,
|
||||
0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,
|
||||
0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,
|
||||
0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,
|
||||
0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,
|
||||
0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,
|
||||
0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,
|
||||
0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,
|
||||
0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,
|
||||
0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,
|
||||
0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,
|
||||
0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,
|
||||
0.00811857, 0.00538216, 0.00685749, 0.00535275,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 0], expected))
|
||||
|
||||
def test_spectrogram_shapes(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann"),
|
||||
frame_length=400,
|
||||
hop_length=128,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (201, 732))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann"),
|
||||
frame_length=400,
|
||||
hop_length=128,
|
||||
power=1.0,
|
||||
center=False,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (201, 729))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann"),
|
||||
frame_length=400,
|
||||
hop_length=128,
|
||||
fft_length=512,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann", frame_length=512),
|
||||
frame_length=512,
|
||||
hop_length=64,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=False,
|
||||
)
|
||||
self.assertEqual(spec.shape, (512, 1464))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(512, "hann"),
|
||||
frame_length=512,
|
||||
hop_length=64,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=False,
|
||||
)
|
||||
self.assertEqual(spec.shape, (512, 1464))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(512, "hann"),
|
||||
frame_length=512,
|
||||
hop_length=512,
|
||||
power=1.0,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=False,
|
||||
)
|
||||
self.assertEqual(spec.shape, (512, 183))
|
||||
|
||||
def test_mel_spectrogram(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=513,
|
||||
num_mel_filters=13,
|
||||
min_frequency=100,
|
||||
max_frequency=4000,
|
||||
sampling_rate=16000,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(800, "hann", frame_length=1024),
|
||||
frame_length=1024,
|
||||
hop_length=128,
|
||||
power=2.0,
|
||||
)
|
||||
self.assertEqual(spec.shape, (513, 732))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(800, "hann", frame_length=1024),
|
||||
frame_length=1024,
|
||||
hop_length=128,
|
||||
power=2.0,
|
||||
mel_filters=mel_filters,
|
||||
)
|
||||
self.assertEqual(spec.shape, (13, 732))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,
|
||||
8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,
|
||||
7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,
|
||||
9.44153646e-04
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:, 300], expected))
|
||||
|
||||
def test_spectrogram_power(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann", frame_length=512),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
power=None,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertEqual(spec.dtype, np.complex64)
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.01452305+0.01820039j, -0.01737362-0.01641946j,
|
||||
0.0121028 +0.01565081j, -0.02794554-0.03021514j,
|
||||
0.04719803+0.04086519j, -0.04391563-0.02779365j,
|
||||
0.05682834+0.01571325j, -0.08604821-0.02023657j,
|
||||
0.07497991+0.0186641j , -0.06366091-0.00922475j,
|
||||
0.11003416+0.0114788j , -0.13677941-0.01523552j,
|
||||
0.10934535-0.00117226j, -0.11635598+0.02551187j,
|
||||
0.14708674-0.03469823j, -0.1328196 +0.06034218j,
|
||||
0.12667368-0.13973421j, -0.14764774+0.18912019j,
|
||||
0.10235471-0.12181523j, -0.00773012+0.04730498j,
|
||||
-0.01487191-0.07312611j, -0.02739162+0.09619419j,
|
||||
0.02895459-0.05398273j, 0.01198589+0.05276592j,
|
||||
-0.02117299-0.10123465j, 0.00666388+0.09526499j,
|
||||
-0.01672773-0.05649684j, 0.02723125+0.05939891j,
|
||||
-0.01879361-0.062954j , 0.03686557+0.04568823j,
|
||||
-0.07394181-0.07949649j, 0.06238583+0.13905765j,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[64:96, 321], expected))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann", frame_length=512),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
power=1.0,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertEqual(spec.dtype, np.float64)
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
0.02328461, 0.02390484, 0.01978448, 0.04115711, 0.0624309 ,
|
||||
0.05197181, 0.05896072, 0.08839577, 0.07726794, 0.06432579,
|
||||
0.11063128, 0.13762532, 0.10935163, 0.11911998, 0.15112405,
|
||||
0.14588428, 0.18860507, 0.23992978, 0.15910825, 0.04793241,
|
||||
0.07462307, 0.10001811, 0.06125769, 0.05411011, 0.10342509,
|
||||
0.09549777, 0.05892122, 0.06534349, 0.06569936, 0.05870678,
|
||||
0.10856833, 0.1524107 , 0.11463385, 0.05766969, 0.12385171,
|
||||
0.14472842, 0.11978184, 0.10353675, 0.07244056, 0.03461861,
|
||||
0.02624896, 0.02227475, 0.01238363, 0.00885281, 0.0110049 ,
|
||||
0.00807005, 0.01033663, 0.01703181, 0.01445856, 0.00585615,
|
||||
0.0132431 , 0.02754132, 0.01524478, 0.0204908 , 0.07453328,
|
||||
0.10716327, 0.07195779, 0.08816078, 0.18340898, 0.16449876,
|
||||
0.12322842, 0.1621659 , 0.12334293, 0.06033659,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[64:128, 321], expected))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "hann", frame_length=512),
|
||||
frame_length=512,
|
||||
hop_length=128,
|
||||
power=2.0,
|
||||
)
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertEqual(spec.dtype, np.float64)
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([
|
||||
5.42173162e-04, 5.71441371e-04, 3.91425507e-04, 1.69390778e-03,
|
||||
3.89761780e-03, 2.70106923e-03, 3.47636663e-03, 7.81381316e-03,
|
||||
5.97033510e-03, 4.13780799e-03, 1.22392802e-02, 1.89407300e-02,
|
||||
1.19577805e-02, 1.41895693e-02, 2.28384770e-02, 2.12822221e-02,
|
||||
3.55718732e-02, 5.75663000e-02, 2.53154356e-02, 2.29751552e-03,
|
||||
5.56860259e-03, 1.00036217e-02, 3.75250424e-03, 2.92790355e-03,
|
||||
1.06967501e-02, 9.11982451e-03, 3.47171025e-03, 4.26977174e-03,
|
||||
4.31640586e-03, 3.44648538e-03, 1.17870830e-02, 2.32290216e-02,
|
||||
1.31409196e-02, 3.32579296e-03, 1.53392460e-02, 2.09463164e-02,
|
||||
1.43476883e-02, 1.07198600e-02, 5.24763530e-03, 1.19844836e-03,
|
||||
6.89007982e-04, 4.96164430e-04, 1.53354369e-04, 7.83722571e-05,
|
||||
1.21107812e-04, 6.51257360e-05, 1.06845939e-04, 2.90082477e-04,
|
||||
2.09049831e-04, 3.42945241e-05, 1.75379610e-04, 7.58524227e-04,
|
||||
2.32403356e-04, 4.19872697e-04, 5.55520924e-03, 1.14839673e-02,
|
||||
5.17792348e-03, 7.77232368e-03, 3.36388536e-02, 2.70598419e-02,
|
||||
1.51852425e-02, 2.62977779e-02, 1.52134784e-02, 3.64050455e-03,
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[64:128, 321], expected))
|
||||
|
||||
def test_power_to_db(self):
|
||||
spectrogram = np.zeros((2, 3))
|
||||
spectrogram[0, 0] = 2.0
|
||||
spectrogram[0, 1] = 0.5
|
||||
spectrogram[0, 2] = 0.707
|
||||
spectrogram[1, 1] = 1.0
|
||||
|
||||
output = power_to_db(spectrogram, reference=1.0)
|
||||
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-100.0, 0.0, -100.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = power_to_db(spectrogram, reference=2.0)
|
||||
expected = np.array([[0.0, -6.02059991, -4.51610582], [-103.01029996, -3.01029996, -103.01029996]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = power_to_db(spectrogram, min_value=1e-6)
|
||||
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-60.0, 0.0, -60.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = power_to_db(spectrogram, db_range=80)
|
||||
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-76.98970004, 0.0, -76.98970004]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = power_to_db(spectrogram, reference=2.0, db_range=80)
|
||||
expected = np.array([[0.0, -6.02059991, -4.51610582], [-80.0, -3.01029996, -80.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = power_to_db(spectrogram, reference=2.0, min_value=1e-6, db_range=80)
|
||||
expected = np.array([[0.0, -6.02059991, -4.51610582], [-63.01029996, -3.01029996, -63.01029996]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
power_to_db(spectrogram, reference=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
power_to_db(spectrogram, min_value=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
power_to_db(spectrogram, db_range=-80)
|
||||
|
||||
def test_amplitude_to_db(self):
|
||||
spectrogram = np.zeros((2, 3))
|
||||
spectrogram[0, 0] = 2.0
|
||||
spectrogram[0, 1] = 0.5
|
||||
spectrogram[0, 2] = 0.707
|
||||
spectrogram[1, 1] = 1.0
|
||||
|
||||
output = amplitude_to_db(spectrogram, reference=1.0)
|
||||
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-100.0, 0.0, -100.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = amplitude_to_db(spectrogram, reference=2.0)
|
||||
expected = np.array([[0.0, -12.04119983, -9.03221164], [-106.02059991, -6.02059991, -106.02059991]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = amplitude_to_db(spectrogram, min_value=1e-3)
|
||||
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-60.0, 0.0, -60.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = amplitude_to_db(spectrogram, db_range=80)
|
||||
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-73.97940009, 0.0, -73.97940009]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = amplitude_to_db(spectrogram, reference=2.0, db_range=80)
|
||||
expected = np.array([[0.0, -12.04119983, -9.03221164], [-80.0, -6.02059991, -80.0]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
output = amplitude_to_db(spectrogram, reference=2.0, min_value=1e-3, db_range=80)
|
||||
expected = np.array([[0.0, -12.04119983, -9.03221164], [-66.02059991, -6.02059991, -66.02059991]])
|
||||
self.assertTrue(np.allclose(output, expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
amplitude_to_db(spectrogram, reference=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
amplitude_to_db(spectrogram, min_value=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
amplitude_to_db(spectrogram, db_range=-80)
|
Loading…
Reference in New Issue
Block a user