diff --git a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py index b7111fe0cac..71ada3a8c62 100644 --- a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py @@ -20,7 +20,7 @@ from typing import Optional, Union import numpy as np -from ...audio_utils import AudioInput +from ...audio_utils import AudioInput, mel_filter_bank from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...image_processing_utils import BatchFeature from ...utils import TensorType, is_torch_available, logging @@ -33,66 +33,6 @@ if is_torch_available(): logger = logging.get_logger(__name__) -# TODO: @eustlb, remove this once #36603 is merged. -def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): - """Create a Mel filter-bank the same as SpeechLib FbankFC. - - Args: - sample_rate (int): Sample rate in Hz. number > 0 [scalar] - n_fft (int): FFT size. int > 0 [scalar] - n_mel (int): Mel filter size. int > 0 [scalar] - fmin (float): lowest frequency (in Hz). If None use 0.0. - float >= 0 [scalar] - fmax: highest frequency (in Hz). If None use sample_rate / 2. - float >= 0 [scalar] - - Returns - out (numpy.ndarray): Mel transform matrix - [shape=(n_mels, 1 + n_fft/2)] - """ - - bank_width = int(n_fft // 2 + 1) - if fmax is None: - fmax = sample_rate / 2 - if fmin is None: - fmin = 0 - assert fmin >= 0, "fmin cannot be negative" - assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]" - - def mel(f): - return 1127.0 * np.log(1.0 + f / 700.0) - - def bin2mel(fft_bin): - return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) - - def f2bin(f): - return int((f * n_fft / sample_rate) + 0.5) - - # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] - klo = f2bin(fmin) + 1 - khi = f2bin(fmax) - - khi = max(khi, klo) - - # Spec 2: SpeechLib uses triangles in Mel space - mlo = mel(fmin) - mhi = mel(fmax) - m_centers = np.linspace(mlo, mhi, n_mels + 2) - ms = (mhi - mlo) / (n_mels + 1) - - matrix = np.zeros((n_mels, bank_width), dtype=np.float32) - for m in range(0, n_mels): - left = m_centers[m] - center = m_centers[m + 1] - right = m_centers[m + 2] - for fft_bin in range(klo, khi): - mbin = bin2mel(fft_bin) - if left < mbin < right: - matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms - - return matrix - - class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor): model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"] @@ -123,19 +63,15 @@ class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor): self.audio_downsample_rate = audio_downsample_rate self.audio_feat_stride = audio_feat_stride - # TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged. - # self.mel_filters = mel_filter_bank( - # num_frequency_bins=self.n_fft // 2 + 1, - # num_mel_filters=self.feature_size, - # min_frequency=mel_min_frequency, - # max_frequency=mel_max_frequency, - # sampling_rate=self.sampling_rate, - # triangularize_in_mel_space=True, - # mel_scale="kaldi", - # ) - self.mel_filters = speechlib_mel( - self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency - ).T + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_fft // 2 + 1, + num_mel_filters=self.feature_size, + min_frequency=mel_min_frequency, + max_frequency=mel_max_frequency, + sampling_rate=self.sampling_rate, + triangularize_in_mel_space=True, + mel_scale="kaldi", + ) def __call__( self,