[FIX] Fix speech2test modeling tests (#29672)

* fix speech_to_test generation tests

* Add details to comment

* Update tests/models/speech_to_text/test_modeling_speech_to_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe 2024-03-15 17:58:11 +00:00 committed by GitHub
parent 9e4df7c424
commit 4e98d59443
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -284,6 +284,18 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
input_name = "input_features"
def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self)
# `input_ids` is actually `input_features` which is a 3D tensor.
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
# attention mask of the same shape as `input_ids`.
if len(attention_mask.shape) > 2:
sequence_length = input_ids.shape[1]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
return config, input_ids, attention_mask, max_length
def setUp(self):
self.model_tester = Speech2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)