From 1ef86c4f56e204384369059f574dc4a823fb7f07 Mon Sep 17 00:00:00 2001 From: Nicholas Neo <45549785+nicholasneo78@users.noreply.github.com> Date: Fri, 22 Dec 2023 18:47:30 +0800 Subject: [PATCH] Fix: [SeamlessM4T - S2TT] Bug in batch loading of audio in torch.Tensor format in the SeamlessM4TFeatureExtractor class (#27914) * fixes: code fixes on is_batched condition to also check for batched audio data in torch.Tensor format instead of only just checking for batched audio data in np.ndarray format * Update src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * refactor: code refactoring to remove torch framework dependency * docs: updated docstring to add torch tensor compatibility * test: add test cases to incorporate torch tensor inputs * test: ran make fix-copies for code conformity * test: refactor test to separate the test_call into test_call_numpy and test_call_torch --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- .../feature_extraction_seamless_m4t.py | 28 ++++++++++---- .../test_feature_extraction_seamless_m4t.py | 37 ++++++++++++++++++- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py index 13bb687dd59..1f1e94385f9 100644 --- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -20,6 +20,12 @@ from typing import List, Optional, Union import numpy as np +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + from ...audio_utils import mel_filter_bank, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature @@ -152,14 +158,17 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): Main method to featurize and prepare for the model one or several sequence(s). Args: - raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays, a list of list of float values or a list of a list of list of float - values. If `raw_speech` is a one-dimensional `np.ndarray` or a `List[float]`, `raw_speech` is + raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`, + `List[List[float]]`, `List[List[List[float]]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, + a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors, + a list of list of float values or a list of a list of list of float values. + If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is considered a single-channel, single-sample sound. In all other cases, the first dimension of - `raw_speech`, whether from an `np.ndarray` or a `List[...]`, corresponds to the number of samples in - the batch, and the number of channels (i.e. mono or stereo character) is derived from the other - dimensions (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches). + `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`, + corresponds to the number of samples in the batch, and the number of channels + (i.e. mono or stereo character) is derived from the other dimensions + (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches). padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: @@ -224,8 +233,11 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): if is_batched_numpy and len(raw_speech.shape) > 3: raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") + acceptable_types = ( + (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list) + ) is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types)) ) if is_batched: diff --git a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py index 2b55a61d812..8ea1025f0ee 100644 --- a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py @@ -139,7 +139,7 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt dict_second = feat_extract_second.to_dict() self.assertEqual(dict_first, dict_second) - def test_call(self): + def test_call_numpy(self): # Tests that all call wrap to encode_plus and batch_encode_plus feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) # create three inputs of length 800, 1000, and 1200 @@ -171,6 +171,41 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + @require_torch + def test_call_torch(self): + import torch + + # Tests that all call wrap to encode_plus and batch_encode_plus + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + pt_speech_inputs = [torch.tensor(speech_input) for speech_input in speech_inputs] + + # Test feature size + input_features = feature_extractor(pt_speech_inputs, padding=True, return_tensors="pt").input_features + self.assertTrue(input_features.ndim == 3) + self.assertTrue(input_features.shape[0] == 3) + self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride) + + # Test not batched input + encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="pt").input_features + encoded_sequences_2 = feature_extractor(pt_speech_inputs[0], return_tensors="pt").input_features + self.assertTrue(torch.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test batched + encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features + encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + # Test 2-D numpy arrays are batched. + speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)] + pt_speech_inputs = torch.tensor(speech_inputs) + encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="pt").input_features + encoded_sequences_2 = feature_extractor(pt_speech_inputs, return_tensors="pt").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(torch.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + @require_torch # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad def test_double_precision_pad(self):