mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 06:20:22 +06:00
Correct naming pegasus x (#18896)
* add first generation tutorial * [Pegasus X] correct naming * [Generation] Remove
This commit is contained in:
parent
591cfc6c90
commit
badb9d2aaa
@ -559,7 +559,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
|
||||
return PegasusTokenizer.from_pretrained("google/pegasus-x-base")
|
||||
|
||||
def test_inference_no_head(self):
|
||||
model = PegasusXModel.from_pretrained("pegasus-x-base").to(torch_device)
|
||||
model = PegasusXModel.from_pretrained("google/pegasus-x-base").to(torch_device)
|
||||
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]])
|
||||
inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids)
|
||||
@ -574,7 +574,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
def test_inference_head(self):
|
||||
model = PegasusXForConditionalGeneration.from_pretrained("pegasus-x-base").to(torch_device)
|
||||
model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base").to(torch_device)
|
||||
|
||||
# change to intended input
|
||||
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
|
Loading…
Reference in New Issue
Block a user