Fix: [SeamlessM4T - S2TT] Bug in batch loading of audio in torch.Tensor format in the SeamlessM4TFeatureExtractor class (#27914)

* fix: update the `is_batched` condition to also detect batched audio passed as `torch.Tensor`, instead of only recognizing batched audio in `np.ndarray` format (a usage sketch follows this list)

* Update src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* refactor: remove the hard dependency on the torch framework by guarding the import

* docs: update docstring to document torch tensor compatibility

* test: add test cases to cover torch tensor inputs

* test: ran make fix-copies for code conformity

* test: split `test_call` into `test_call_numpy` and `test_call_torch`
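For context, a minimal usage sketch (not part of the commit) of the call pattern this fix enables. The checkpoint name, sampling rate, and waveform lengths are illustrative assumptions:

import torch
from transformers import SeamlessM4TFeatureExtractor

# Checkpoint name is an assumption for illustration.
extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")

# A batch of three mono waveforms of different lengths, passed as torch tensors.
# Before this fix, only a list of np.ndarray was recognized as batched input.
waveforms = [torch.randn(n) for n in (800, 1000, 1200)]

features = extractor(waveforms, sampling_rate=16000, padding=True, return_tensors="pt")
print(features.input_features.shape)  # (batch, num_frames, feature_size * stride)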

---------

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Nicholas Neo, 2023-12-22 18:47:30 +08:00, committed by GitHub
parent 548a8f6119
commit 1ef86c4f56
2 changed files with 56 additions and 9 deletions

src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py

@@ -20,6 +20,12 @@ from typing import List, Optional, Union

 import numpy as np

+from ...utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
 from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
@@ -152,14 +158,17 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         Main method to featurize and prepare for the model one or several sequence(s).

         Args:
-            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
-                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays, a list of list of float values or a list of a list of list of float
-                values. If `raw_speech` is a one-dimensional `np.ndarray` or a `List[float]`, `raw_speech` is
+            raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
+            `List[List[float]]`, `List[List[List[float]]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
+                a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
+                a list of list of float values or a list of a list of list of float values.
+                If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
                 considered a single-channel, single-sample sound. In all other cases, the first dimension of
-                `raw_speech`, whether from an `np.ndarray` or a `List[...]`, corresponds to the number of samples in
-                the batch, and the number of channels (i.e. mono or stereo character) is derived from the other
-                dimensions (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
+                `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
+                corresponds to the number of samples in the batch, and the number of channels
+                (i.e. mono or stereo character) is derived from the other dimensions
+                (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
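A quick illustration of the shape conventions described in the docstring above (torch assumed; sizes are illustrative):

import torch

single = torch.randn(800)        # 1-D: a single-channel, single-sample sound
batch_mono = torch.randn(3, 800) # first dim is the batch; each sample is a 1-D mono waveform
# A 3-D input is likewise a batch whose per-sample data is 2-D, i.e. stereo waveforms.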
@@ -224,8 +233,11 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
         if is_batched_numpy and len(raw_speech.shape) > 3:
             raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")

+        acceptable_types = (
+            (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
+        )
         is_batched = is_batched_numpy or (
-            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
         )

         if is_batched:
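A self-contained sketch of the detection logic above, with a hypothetical helper name `is_batched_audio`; the `is_batched_numpy` definition is reproduced as it commonly appears in transformers feature extractors, and torch is assumed installed:

import numpy as np
import torch

def is_batched_audio(raw_speech):  # hypothetical helper, mirrors the diff above
    # A 2-D/3-D ndarray is batched: its first dimension indexes the samples.
    is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
    acceptable_types = (torch.Tensor, np.ndarray, tuple, list)
    return is_batched_numpy or (
        isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], acceptable_types)
    )

assert is_batched_audio([torch.randn(800), torch.randn(1000)])  # the case this commit fixes
assert is_batched_audio([np.zeros(800), np.zeros(1000)])        # previously handled
assert not is_batched_audio(np.zeros(800))                      # single 1-D waveform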

tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py

@@ -139,7 +139,7 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         dict_second = feat_extract_second.to_dict()
         self.assertEqual(dict_first, dict_second)

-    def test_call(self):
+    def test_call_numpy(self):
         # Tests that all call wrap to encode_plus and batch_encode_plus
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         # create three inputs of length 800, 1000, and 1200
@@ -171,6 +171,41 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

+    @require_torch
+    def test_call_torch(self):
+        import torch
+
+        # Tests that all call wrap to encode_plus and batch_encode_plus
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        # create three inputs of length 800, 1000, and 1200
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        pt_speech_inputs = [torch.tensor(speech_input) for speech_input in speech_inputs]
+
+        # Test feature size
+        input_features = feature_extractor(pt_speech_inputs, padding=True, return_tensors="pt").input_features
+        self.assertTrue(input_features.ndim == 3)
+        self.assertTrue(input_features.shape[0] == 3)
+        self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride)
+
+        # Test not batched input
+        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs[0], return_tensors="pt").input_features
+        self.assertTrue(torch.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+        # Test batched
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        pt_speech_inputs = torch.tensor(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features
+        encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     @require_torch
     # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
     def test_double_precision_pad(self):
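The property the new test asserts can also be checked against the public API. A hedged sketch, with the checkpoint name and sampling rate as illustrative assumptions:

import numpy as np
import torch
from transformers import SeamlessM4TFeatureExtractor

# Checkpoint name and sampling rate are illustrative assumptions.
extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")

np_inputs = [np.random.randn(n).astype(np.float32) for n in (800, 1000, 1200)]
pt_inputs = [torch.from_numpy(a) for a in np_inputs]

np_feats = extractor(np_inputs, sampling_rate=16000, padding=True, return_tensors="pt").input_features
pt_feats = extractor(pt_inputs, sampling_rate=16000, padding=True, return_tensors="pt").input_features

# numpy and torch inputs should produce numerically identical features.
assert torch.allclose(np_feats, pt_feats, atol=1e-3)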