Fix test_t5_decoder_model_past_large_inputs (#17320)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-05-18 17:57:23 +02:00 committed by GitHub
parent 6da76b9c2a
commit b3b9f99ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -295,6 +295,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
def test_t5_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# `create_and_check_t5_decoder_model_past_large_inputs` has special inputs:
# (config, input_ids, decoder_input_ids, attention_mask)
# and we have to prepare it correctly here.
config, input_ids, input_mask, token_labels = config_and_inputs
config_and_inputs = (config, input_ids, None, input_mask)
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_t5_model_xla_generate_fast(self):