Fix error in M4T feature extractor (#28340)

* fix M4T FE error when no attention mask

* modify logic

* add test

* go back to initial test situation + add other tests
This commit is contained in:
Yoach Lacombe 2024-01-04 17:40:53 +01:00 committed by GitHub
parent 4a66c0d952
commit 35e9d2b223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 3 deletions

View File

@ -229,6 +229,10 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
return_attention_mask = (
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
)
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 3:
raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
@ -270,13 +274,13 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_attention_mask=True,
return_tensors="np",
)
# SeamlessM4T needs to process extracted features
input_features = padded_inputs.get("input_features")
attention_mask = padded_inputs.get("attention_mask")
attention_mask = padded_inputs.pop("attention_mask")
batch_size, num_frames, num_channels = input_features.shape
@ -293,7 +297,8 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
attention_mask = attention_mask[:, indices % self.stride == 1]
padded_inputs["input_features"] = input_features
padded_inputs["attention_mask"] = attention_mask
if return_attention_mask:
padded_inputs["attention_mask"] = attention_mask
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

View File

@ -171,6 +171,42 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_call_without_attention_mask(self):
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test attention mask when passing no attention mask to forward call
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
self.assertTrue("attention_mask" not in output)
# Test attention mask when no attention mask by default
feature_extractor_args["return_attention_mask"] = False
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
self.assertTrue("attention_mask" not in output)
def test_attention_mask(self):
# test attention mask has the right output shape
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test attention mask when passing it to forward call
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np")
input_features = output.input_features
attention_mask = output.attention_mask
self.assertTrue(attention_mask.ndim == 2)
self.assertTrue(attention_mask.shape[0] == 3)
self.assertTrue(attention_mask.shape[-1] == input_features.shape[1])
@require_torch
def test_call_torch(self):
import torch