Correct naming pegasus x (#18896)

* add first generation tutorial

* [Pegasus X] correct naming

* [Generation] Remove
This commit is contained in:
Patrick von Platen 2022-09-05 11:25:00 +02:00 committed by GitHub
parent 591cfc6c90
commit badb9d2aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -559,7 +559,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
return PegasusTokenizer.from_pretrained("google/pegasus-x-base")
def test_inference_no_head(self):
model = PegasusXModel.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXModel.from_pretrained("google/pegasus-x-base").to(torch_device)
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]])
inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids)
@ -574,7 +574,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self):
model = PegasusXForConditionalGeneration.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base").to(torch_device)
# change to intended input
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])