Fix failing test_batch_generation for bloom (#25718)

fix: keep the tokenizer output under its own `inputs` name so `attention_mask` is read from the encoding dict rather than from the `input_ids` tensor (a standalone sketch of the fixed pattern follows the diff)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar authored on 2023-08-24 11:15:29 +02:00, committed via GitHub
parent f01459c75d
commit 8fff61b9db

@@ -449,9 +449,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
-        input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
-        input_ids = input_ids["input_ids"].to(torch_device)
-        attention_mask = input_ids["attention_mask"]
+        inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+        input_ids = inputs["input_ids"].to(torch_device)
+        attention_mask = inputs["attention_mask"]
         greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
         self.assertEqual(
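
For reference, a minimal standalone sketch of the fixed pattern. The `bigscience/bloom-560m` checkpoint and the device selection are illustrative assumptions for this sketch, not what the test itself instantiates:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint and device; the test builds its own model/tokenizer.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "bigscience/bloom-560m"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)

    input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]

    # Keep the tokenizer output (a dict-like BatchEncoding) under its own name so
    # both "input_ids" and "attention_mask" can still be read from it. The buggy
    # version re-used the name `input_ids` for the tensor and then indexed that
    # tensor with the string key "attention_mask", which fails.
    inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(torch_device)
    attention_mask = inputs["attention_mask"].to(torch_device)

    greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
    print(tokenizer.batch_decode(greedy_output, skip_special_tokens=True))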