diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py index 1f1e94385f9..0d4879a35ea 100644 --- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -229,6 +229,10 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): "Failing to do so can result in silent errors that might be hard to debug." ) + return_attention_mask = ( + return_attention_mask if return_attention_mask is not None else self.return_attention_mask + ) + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 if is_batched_numpy and len(raw_speech.shape) > 3: raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") @@ -270,13 +274,13 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, + return_attention_mask=True, return_tensors="np", ) # SeamlessM4T needs to process extracted features input_features = padded_inputs.get("input_features") - attention_mask = padded_inputs.get("attention_mask") + attention_mask = padded_inputs.pop("attention_mask") batch_size, num_frames, num_channels = input_features.shape @@ -293,7 +297,8 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): attention_mask = attention_mask[:, indices % self.stride == 1] padded_inputs["input_features"] = input_features - padded_inputs["attention_mask"] = attention_mask + if return_attention_mask: + padded_inputs["attention_mask"] = attention_mask if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py index 
8ea1025f0ee..a8fca4b90ba 100644 --- a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py @@ -171,6 +171,42 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + def test_call_without_attention_mask(self): + feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict() + feature_extractor = self.feature_extraction_class(**feature_extractor_args) + + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # Test attention mask when passing no attention mask to forward call + output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False) + self.assertTrue("attention_mask" not in output) + + # Test attention mask when no attention mask by default + feature_extractor_args["return_attention_mask"] = False + feature_extractor = self.feature_extraction_class(**feature_extractor_args) + output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np") + self.assertTrue("attention_mask" not in output) + + def test_attention_mask(self): + # test attention mask has the right output shape + feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict() + + feature_extractor = self.feature_extraction_class(**feature_extractor_args) + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # Test attention mask when passing it to forward call + output = feature_extractor(np_speech_inputs, padding=True, 
return_tensors="np") + input_features = output.input_features + + attention_mask = output.attention_mask + self.assertTrue(attention_mask.ndim == 2) + self.assertTrue(attention_mask.shape[0] == 3) + self.assertTrue(attention_mask.shape[-1] == input_features.shape[1]) + @require_torch def test_call_torch(self): import torch