Fix CI for PegasusX (#19025)

* Skip test_torchscript_output_attentions for PegasusXModelTest

* fix test_inference_no_head

* fix test_inference_head

* fix test_seq_to_seq_generation

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-09-14 14:45:00 +02:00 committed by GitHub
parent 77ea35b93a
commit 77b18783c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -206,6 +206,12 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
self.model_tester = PegasusXModelTester(self)
self.config_tester = ConfigTester(self, config_class=PegasusXConfig)
@unittest.skip(
"`PegasusXGlobalLocalAttention` returns attentions as dictionary - not compatible with torchscript "
)
def test_torchscript_output_attentions(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
@ -565,12 +571,13 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids)
with torch.no_grad():
output = model(**inputs_dict)[0]
expected_shape = torch.Size((1, 11, 1024))
expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self):
@ -586,13 +593,13 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
def test_seq_to_seq_generation(self):
hf = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base").to(torch_device)
tok = PegasusTokenizer.from_pretrained("google/pegasus-x-large")
hf = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base-arxiv").to(torch_device)
tok = PegasusTokenizer.from_pretrained("google/pegasus-x-base")
batch_input = [
"While large pretrained Transformer models have proven highly capable at tackling natural language tasks,"
@ -626,7 +633,8 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
)
EXPECTED = [
"we investigate the performance of a new pretrained model for long input summarization. <n> the model"
"we investigate the performance of a new pretrained model for long input summarization. <n> the model is a"
" superposition of two well -"
]
generated = tok.batch_decode(