Mirror of https://github.com/huggingface/transformers.git, last synced 2025-07-31 02:02:21 +06:00.
Fix error in M4T feature extractor (#28340)
* Fix M4T feature extractor error when no attention mask is provided
* Modify the attention-mask handling logic
* Add a test
* Restore the initial test situation and add further tests
This commit is contained in:
parent
4a66c0d952
commit
35e9d2b223
@ -229,6 +229,10 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
return_attention_mask = (
|
||||
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
|
||||
)
|
||||
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 3:
|
||||
raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
|
||||
@ -270,13 +274,13 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_attention_mask=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# SeamlessM4T needs to process extracted features
|
||||
input_features = padded_inputs.get("input_features")
|
||||
attention_mask = padded_inputs.get("attention_mask")
|
||||
attention_mask = padded_inputs.pop("attention_mask")
|
||||
|
||||
batch_size, num_frames, num_channels = input_features.shape
|
||||
|
||||
@ -293,7 +297,8 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
||||
attention_mask = attention_mask[:, indices % self.stride == 1]
|
||||
|
||||
padded_inputs["input_features"] = input_features
|
||||
padded_inputs["attention_mask"] = attention_mask
|
||||
if return_attention_mask:
|
||||
padded_inputs["attention_mask"] = attention_mask
|
||||
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
@ -171,6 +171,42 @@ class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_call_without_attention_mask(self):
|
||||
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
|
||||
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||
|
||||
# Test attention mask when passing no attention mask to forward call
|
||||
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
|
||||
self.assertTrue("attention_mask" not in output)
|
||||
|
||||
# Test attention mask when no attention mask by default
|
||||
feature_extractor_args["return_attention_mask"] = False
|
||||
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
|
||||
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np", return_attention_mask=False)
|
||||
self.assertTrue("attention_mask" not in output)
|
||||
|
||||
def test_attention_mask(self):
|
||||
# test attention mask has the right output shape
|
||||
feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
feature_extractor = self.feature_extraction_class(**feature_extractor_args)
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||
|
||||
# Test attention mask when passing it to forward call
|
||||
output = feature_extractor(np_speech_inputs, padding=True, return_tensors="np")
|
||||
input_features = output.input_features
|
||||
|
||||
attention_mask = output.attention_mask
|
||||
self.assertTrue(attention_mask.ndim == 2)
|
||||
self.assertTrue(attention_mask.shape[0] == 3)
|
||||
self.assertTrue(attention_mask.shape[-1] == input_features.shape[1])
|
||||
|
||||
@require_torch
|
||||
def test_call_torch(self):
|
||||
import torch
|
||||
|
Loading…
Reference in New Issue
Block a user