Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
[FIX] Fix speech2text modeling tests (#29672)
* Fix speech_to_text generation tests

* Add details to comment

* Update tests/models/speech_to_text/test_modeling_speech_to_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 9e4df7c424
commit 4e98d59443
@@ -284,6 +284,18 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
     input_name = "input_features"
 
+    def _get_input_ids_and_config(self, batch_size=2):
+        config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self)
+
+        # `input_ids` is actually `input_features` which is a 3D tensor.
+        # We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
+        # attention mask of the same shape as `input_ids`.
+        if len(attention_mask.shape) > 2:
+            sequence_length = input_ids.shape[1]
+            attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
+
+        return config, input_ids, attention_mask, max_length
+
     def setUp(self):
         self.model_tester = Speech2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
 
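For context, here is a minimal standalone sketch (not part of the commit) of the shape mismatch the override works around; the tensor sizes below are illustrative, not the ones used by the test suite. Speech-to-text models in these tests consume 3D `input_features` (batch, sequence length, feature dimension), so a mask shaped like the inputs would be 3D, while generation expects a 2D (batch, sequence length) attention mask:

    import torch

    # Illustrative sizes only; the real tester picks its own shapes.
    batch_size, sequence_length, feature_dim = 2, 30, 80
    input_features = torch.randn(batch_size, sequence_length, feature_dim)

    # A mask built to match the input shape (what the generic helper effectively
    # produces for `input_ids`-style inputs) is 3D here:
    mask_3d = torch.ones(input_features.shape, dtype=torch.long)
    assert mask_3d.dim() == 3

    # The test override rebuilds it as a 2D mask over time steps instead:
    attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
    assert attention_mask.shape == (batch_size, sequence_length)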