Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Fix: [SeamlessM4T - S2TT] Bug in batch loading of audio in torch.Tensor format in the SeamlessM4TFeatureExtractor class (#27914)
* fixes: code fixes on is_batched condition to also check for batched audio data in torch.Tensor format instead of only checking for batched audio data in np.ndarray format

* Update src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* refactor: code refactoring to remove torch framework dependency

* docs: updated docstring to add torch tensor compatibility

* test: add test cases to incorporate torch tensor inputs

* test: ran make fix-copies for code conformity

* test: refactor test to separate the test_call into test_call_numpy and test_call_torch

---------

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
parent 548a8f6119
commit 1ef86c4f56
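For context, here is a minimal usage sketch of the call pattern this commit fixes. It assumes torch is installed; the checkpoint name is illustrative and any SeamlessM4T checkpoint should work the same way.

import torch
from transformers import SeamlessM4TFeatureExtractor

# Illustrative checkpoint; substitute any SeamlessM4T checkpoint.
extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")

# A batch of mono waveforms held as torch tensors. Before this fix, the
# is_batched check only recognized np.ndarray elements, so a list of
# torch tensors was not treated as a batch.
waveforms = [torch.randn(800), torch.randn(1000), torch.randn(1200)]
features = extractor(waveforms, sampling_rate=16000, padding=True, return_tensors="pt")
print(features.input_features.shape)  # (batch=3, padded_seq_len, feature_size * stride)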
src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
@@ -20,6 +20,12 @@ from typing import List, Optional, Union
 import numpy as np
 
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
 from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
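The guarded import above is the soft-dependency pattern used throughout transformers: the module stays importable when torch is absent, and torch.Tensor is only referenced on code paths where torch is known to exist. A minimal sketch of the same pattern, with a hypothetical helper name:

import numpy as np

from transformers.utils import is_torch_available

if is_torch_available():
    import torch


def as_numpy(sample):
    # Hypothetical helper, not part of this PR: normalize a possibly-torch
    # input to numpy without naming torch.Tensor when torch is missing.
    if is_torch_available() and isinstance(sample, torch.Tensor):
        return sample.detach().cpu().numpy()
    return np.asarray(sample)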
@@ -152,14 +158,17 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         Main method to featurize and prepare for the model one or several sequence(s).
 
         Args:
-            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
-                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays, a list of list of float values or a list of a list of list of float
-                values. If `raw_speech` is a one-dimensional `np.ndarray` or a `List[float]`, `raw_speech` is
-                considered a single-channel, single-sample sound. In all other cases, the first dimension of
-                `raw_speech`, whether from an `np.ndarray` or a `List[...]`, corresponds to the number of samples in
-                the batch, and the number of channels (i.e. mono or stereo character) is derived from the other
-                dimensions (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
+            raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
+            `List[List[float]]`, `List[List[List[float]]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
+                a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
+                a list of list of float values or a list of a list of list of float values.
+                If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
+                considered a single-channel, single-sample sound. In all other cases, the first dimension of
+                `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
+                corresponds to the number of samples in the batch, and the number of channels
+                (i.e. mono or stereo character) is derived from the other dimensions
+                (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
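To make the documented layouts concrete, a short sketch of equivalent single-sample and batched inputs (torch assumed installed; shapes chosen purely for illustration):

import numpy as np
import torch

mono_np = np.random.randn(16000).astype(np.float32)  # 1-D array: one mono sample
mono_pt = torch.randn(16000)                          # 1-D tensor: the same, in torch

batch_as_list = [mono_np, mono_np]          # list of arrays: a batch of two samples
batch_np = np.stack([mono_np, mono_np])     # 2-D array: first dim is the batch
batch_pt = torch.stack([mono_pt, mono_pt])  # 2-D tensor: first dim is the batch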
@@ -224,8 +233,11 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         if is_batched_numpy and len(raw_speech.shape) > 3:
             raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
 
+        acceptable_types = (
+            (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
+        )
         is_batched = is_batched_numpy or (
-            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
         )
 
         if is_batched:
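A quick way to see the bug the new condition fixes, as a standalone sketch (torch assumed installed):

import numpy as np
import torch

batch = [torch.randn(800), torch.randn(1000)]

# Old condition: torch.Tensor was not among the accepted element types,
# so a list of tensors was silently treated as a single sample.
old_is_batched = isinstance(batch, (list, tuple)) and isinstance(batch[0], (np.ndarray, tuple, list))
print(old_is_batched)  # False

# Fixed condition: torch.Tensor is included when torch is available.
new_is_batched = isinstance(batch, (list, tuple)) and isinstance(
    batch[0], (torch.Tensor, np.ndarray, tuple, list)
)
print(new_is_batched)  # True

Building acceptable_types inside the call also avoids a NameError in torch-free environments, since the import torch above only runs when is_torch_available() is true.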
tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
@@ -139,7 +139,7 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         dict_second = feat_extract_second.to_dict()
         self.assertEqual(dict_first, dict_second)
 
-    def test_call(self):
+    def test_call_numpy(self):
         # Tests that all call wrap to encode_plus and batch_encode_plus
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         # create three inputs of length 800, 1000, and 1200
@@ -171,6 +171,41 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+    @require_torch
+    def test_call_torch(self):
+        import torch
+
+        # Tests that all call wrap to encode_plus and batch_encode_plus
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        # create three inputs of length 800, 1000, and 1200
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        pt_speech_inputs = [torch.tensor(speech_input) for speech_input in speech_inputs]
+
+        # Test feature size
+        input_features = feature_extractor(pt_speech_inputs, padding=True, return_tensors="pt").input_features
+        self.assertTrue(input_features.ndim == 3)
+        self.assertTrue(input_features.shape[0] == 3)
+        self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride)
+
+        # Test not batched input
+        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs[0], return_tensors="pt").input_features
+        self.assertTrue(torch.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+        # Test batched
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        pt_speech_inputs = torch.tensor(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     @require_torch
     # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
     def test_double_precision_pad(self):
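To exercise just the split tests from a development checkout, an invocation along the lines of python -m pytest tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py -k "test_call" should run both test_call_numpy and test_call_torch; the torch variant is skipped automatically when torch is unavailable, courtesy of the @require_torch decorator.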
|
Loading…
Reference in New Issue
Block a user