Fix: [SeamlessM4T - S2TT] Bug in batch loading of audio in torch.Tensor format in the SeamlessM4TFeatureExtractor class (#27914)

* fix: extend the is_batched condition to also detect batched audio data passed as torch.Tensor, not only as np.ndarray

* Update src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* refactor: remove the hard dependency on the torch framework

* docs: update the docstring to document torch.Tensor compatibility

* test: add test cases to incorporate torch tensor inputs

* test: ran make fix-copies for code conformity

* test: split test_call into test_call_numpy and test_call_torch

---------

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Nicholas Neo committed 2023-12-22 18:47:30 +08:00 (committed by GitHub)
parent 548a8f6119
commit 1ef86c4f56
2 changed files with 56 additions and 9 deletions
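
For context, a minimal usage sketch of what this fix enables: passing torch tensors (single or batched) directly to the feature extractor. This is not part of the commit; the checkpoint name and sampling rate below are illustrative assumptions.

# Illustrative only: checkpoint name and sampling rate are assumptions.
import torch
from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")

# A batch of two mono waveforms of different lengths, given as torch tensors.
waveforms = [torch.randn(16000), torch.randn(24000)]

inputs = feature_extractor(waveforms, sampling_rate=16000, padding=True, return_tensors="pt")
print(inputs.input_features.shape)  # (batch, num_frames, feature_size * stride)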

src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py

@@ -20,6 +20,12 @@ from typing import List, Optional, Union
 
 import numpy as np
 
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
 from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
@@ -152,14 +158,17 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         Main method to featurize and prepare for the model one or several sequence(s).
 
         Args:
-            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
-                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays, a list of list of float values or a list of a list of list of float
-                values. If `raw_speech` is a one-dimensional `np.ndarray` or a `List[float]`, `raw_speech` is
+            raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
+            `List[List[float]]`, `List[List[List[float]]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
+                a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
+                a list of list of float values or a list of a list of list of float values.
+                If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
                 considered a single-channel, single-sample sound. In all other cases, the first dimension of
-                `raw_speech`, whether from an `np.ndarray` or a `List[...]`, corresponds to the number of samples in
-                the batch, and the number of channels (i.e. mono or stereo character) is derived from the other
-                dimensions (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
+                `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
+                corresponds to the number of samples in the batch, and the number of channels
+                (i.e. mono or stereo character) is derived from the other dimensions
+                (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
@@ -224,8 +233,11 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         if is_batched_numpy and len(raw_speech.shape) > 3:
             raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
 
+        acceptable_types = (
+            (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
+        )
         is_batched = is_batched_numpy or (
-            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
         )
 
         if is_batched:
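
For readers skimming the hunk above, a standalone sketch of the batched-input check it changes (simplified; it assumes torch is installed, whereas the library code guards the import behind is_torch_available):

import numpy as np
import torch

def looks_batched(raw_speech):
    # Mirrors the fixed condition: a list/tuple whose first element is itself
    # array-like counts as a batch, whether numpy or torch.
    acceptable_types = (torch.Tensor, np.ndarray, tuple, list)
    return isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], acceptable_types)

print(looks_batched([np.zeros(800), np.zeros(1000)]))        # True (already worked)
print(looks_batched([torch.zeros(800), torch.zeros(1000)]))  # True (the case this commit fixes)
print(looks_batched([0.1, 0.2, 0.3]))                        # False: a single sequence of floats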

tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py

@@ -139,7 +139,7 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         dict_second = feat_extract_second.to_dict()
         self.assertEqual(dict_first, dict_second)
 
-    def test_call(self):
+    def test_call_numpy(self):
         # Tests that all call wrap to encode_plus and batch_encode_plus
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         # create three inputs of length 800, 1000, and 1200
@@ -171,6 +171,41 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+    @require_torch
+    def test_call_torch(self):
+        import torch
+
+        # Tests that all call wrap to encode_plus and batch_encode_plus
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        # create three inputs of length 800, 1000, and 1200
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        pt_speech_inputs = [torch.tensor(speech_input) for speech_input in speech_inputs]
+
+        # Test feature size
+        input_features = feature_extractor(pt_speech_inputs, padding=True, return_tensors="pt").input_features
+        self.assertTrue(input_features.ndim == 3)
+        self.assertTrue(input_features.shape[0] == 3)
+        self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride)
+
+        # Test not batched input
+        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs[0], return_tensors="pt").input_features
+        self.assertTrue(torch.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+        # Test batched
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        pt_speech_inputs = torch.tensor(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     @require_torch
     # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
     def test_double_precision_pad(self):