Generate: pin number of beams in BART test (#22763)

This commit is contained in:
Joao Gante 2023-04-14 09:57:25 +01:00 committed by GitHub
parent 66b15efb20
commit 9af845afc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1230,7 +1230,7 @@ class BartModelIntegrationTests(unittest.TestCase):
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, num_beams=1)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(