[phi-4] use mel filters from audio utils (#36966)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run

* use mel_filter_bank from audio utils

* Apply style fixes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
eustlb 2025-06-19 05:35:32 +02:00 committed by GitHub
parent f7b21822e3
commit 11738f8537
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -20,7 +20,7 @@ from typing import Optional, Union
import numpy as np
from ...audio_utils import AudioInput
from ...audio_utils import AudioInput, mel_filter_bank
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...image_processing_utils import BatchFeature
from ...utils import TensorType, is_torch_available, logging
@@ -33,66 +33,6 @@ if is_torch_available():
logger = logging.get_logger(__name__)
# TODO: @eustlb, remove this once #36603 is merged.
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter-bank the same as SpeechLib FbankFC.

    The filters are triangles defined in Mel space (not in Hz): each FFT bin
    is mapped to its Mel value and weighted by its distance to the filter
    center, normalized by the (constant) Mel spacing between centers.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Number of Mel filters. int > 0 [scalar]
        fmin (float): lowest frequency (in Hz). If None use 0.0.
            float >= 0 [scalar]
        fmax (float): highest frequency (in Hz). If None use sample_rate / 2.
            float >= 0 [scalar]

    Returns:
        out (numpy.ndarray): float32 Mel transform matrix
            [shape=(n_mels, 1 + n_fft // 2)]

    Raises:
        AssertionError: if `fmin` is negative, or if `fmax` is not in
            the interval (fmin, sample_rate / 2].
    """
    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    assert fmin >= 0, "fmin cannot be negative"
    assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"

    def mel(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    def f2bin(f):
        return int((f * n_fft / sample_rate) + 0.5)

    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
    klo = f2bin(fmin) + 1
    khi = max(f2bin(fmax), klo)

    # Spec 2: SpeechLib uses triangles in Mel space
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    ms = (mhi - mlo) / (n_mels + 1)

    # The Mel value of an FFT bin does not depend on the filter index, so it is
    # computed once for the whole bin range instead of once per (filter, bin).
    fft_bins = np.arange(klo, khi, dtype=np.float64)
    bin_mels = 1127.0 * np.log(1.0 + fft_bins * sample_rate / (n_fft * 700.0))

    # Column vectors of each filter's left edge, center and right edge, so that
    # comparing against `bin_mels` broadcasts to shape (n_mels, khi - klo).
    left = m_centers[:-2, np.newaxis]
    center = m_centers[1:-1, np.newaxis]
    right = m_centers[2:, np.newaxis]

    inside = (left < bin_mels) & (bin_mels < right)
    weights = 1.0 - np.abs(center - bin_mels) / ms

    # Bins outside [klo, khi) and bins outside a filter's support stay at zero.
    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    matrix[:, klo:khi] = np.where(inside, weights, 0.0)
    return matrix
class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
@@ -123,19 +63,15 @@ class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
self.audio_downsample_rate = audio_downsample_rate
self.audio_feat_stride = audio_feat_stride
# TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
# self.mel_filters = mel_filter_bank(
# num_frequency_bins=self.n_fft // 2 + 1,
# num_mel_filters=self.feature_size,
# min_frequency=mel_min_frequency,
# max_frequency=mel_max_frequency,
# sampling_rate=self.sampling_rate,
# triangularize_in_mel_space=True,
# mel_scale="kaldi",
# )
self.mel_filters = speechlib_mel(
self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
).T
self.mel_filters = mel_filter_bank(
num_frequency_bins=self.n_fft // 2 + 1,
num_mel_filters=self.feature_size,
min_frequency=mel_min_frequency,
max_frequency=mel_max_frequency,
sampling_rate=self.sampling_rate,
triangularize_in_mel_space=True,
mel_scale="kaldi",
)
def __call__(
self,