mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 14:58:56 +06:00
[phi-4] use mel filters from audio utils (#36966)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
* use mel_filter_bank from audio utils * Apply style fixes --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
f7b21822e3
commit
11738f8537
@ -20,7 +20,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...audio_utils import AudioInput
|
from ...audio_utils import AudioInput, mel_filter_bank
|
||||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...utils import TensorType, is_torch_available, logging
|
from ...utils import TensorType, is_torch_available, logging
|
||||||
@ -33,66 +33,6 @@ if is_torch_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO: @eustlb, remove this once #36603 is merged.
|
|
||||||
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
|
|
||||||
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
|
|
||||||
n_fft (int): FFT size. int > 0 [scalar]
|
|
||||||
n_mel (int): Mel filter size. int > 0 [scalar]
|
|
||||||
fmin (float): lowest frequency (in Hz). If None use 0.0.
|
|
||||||
float >= 0 [scalar]
|
|
||||||
fmax: highest frequency (in Hz). If None use sample_rate / 2.
|
|
||||||
float >= 0 [scalar]
|
|
||||||
|
|
||||||
Returns
|
|
||||||
out (numpy.ndarray): Mel transform matrix
|
|
||||||
[shape=(n_mels, 1 + n_fft/2)]
|
|
||||||
"""
|
|
||||||
|
|
||||||
bank_width = int(n_fft // 2 + 1)
|
|
||||||
if fmax is None:
|
|
||||||
fmax = sample_rate / 2
|
|
||||||
if fmin is None:
|
|
||||||
fmin = 0
|
|
||||||
assert fmin >= 0, "fmin cannot be negative"
|
|
||||||
assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
|
|
||||||
|
|
||||||
def mel(f):
|
|
||||||
return 1127.0 * np.log(1.0 + f / 700.0)
|
|
||||||
|
|
||||||
def bin2mel(fft_bin):
|
|
||||||
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
|
|
||||||
|
|
||||||
def f2bin(f):
|
|
||||||
return int((f * n_fft / sample_rate) + 0.5)
|
|
||||||
|
|
||||||
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
|
|
||||||
klo = f2bin(fmin) + 1
|
|
||||||
khi = f2bin(fmax)
|
|
||||||
|
|
||||||
khi = max(khi, klo)
|
|
||||||
|
|
||||||
# Spec 2: SpeechLib uses triangles in Mel space
|
|
||||||
mlo = mel(fmin)
|
|
||||||
mhi = mel(fmax)
|
|
||||||
m_centers = np.linspace(mlo, mhi, n_mels + 2)
|
|
||||||
ms = (mhi - mlo) / (n_mels + 1)
|
|
||||||
|
|
||||||
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
|
|
||||||
for m in range(0, n_mels):
|
|
||||||
left = m_centers[m]
|
|
||||||
center = m_centers[m + 1]
|
|
||||||
right = m_centers[m + 2]
|
|
||||||
for fft_bin in range(klo, khi):
|
|
||||||
mbin = bin2mel(fft_bin)
|
|
||||||
if left < mbin < right:
|
|
||||||
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
|
|
||||||
|
|
||||||
return matrix
|
|
||||||
|
|
||||||
|
|
||||||
class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
|
class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
|
||||||
model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
|
model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
|
||||||
|
|
||||||
@ -123,19 +63,15 @@ class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
self.audio_downsample_rate = audio_downsample_rate
|
self.audio_downsample_rate = audio_downsample_rate
|
||||||
self.audio_feat_stride = audio_feat_stride
|
self.audio_feat_stride = audio_feat_stride
|
||||||
|
|
||||||
# TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
|
self.mel_filters = mel_filter_bank(
|
||||||
# self.mel_filters = mel_filter_bank(
|
num_frequency_bins=self.n_fft // 2 + 1,
|
||||||
# num_frequency_bins=self.n_fft // 2 + 1,
|
num_mel_filters=self.feature_size,
|
||||||
# num_mel_filters=self.feature_size,
|
min_frequency=mel_min_frequency,
|
||||||
# min_frequency=mel_min_frequency,
|
max_frequency=mel_max_frequency,
|
||||||
# max_frequency=mel_max_frequency,
|
sampling_rate=self.sampling_rate,
|
||||||
# sampling_rate=self.sampling_rate,
|
triangularize_in_mel_space=True,
|
||||||
# triangularize_in_mel_space=True,
|
mel_scale="kaldi",
|
||||||
# mel_scale="kaldi",
|
)
|
||||||
# )
|
|
||||||
self.mel_filters = speechlib_mel(
|
|
||||||
self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
|
|
||||||
).T
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user