mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix failing test_batch_generation
for bloom (#25718)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
f01459c75d
commit
8fff61b9db
@ -449,9 +449,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
|
||||
|
||||
input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
|
||||
input_ids = input_ids["input_ids"].to(torch_device)
|
||||
attention_mask = input_ids["attention_mask"]
|
||||
inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
attention_mask = inputs["attention_mask"]
|
||||
greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
|
||||
|
||||
self.assertEqual(
|
||||
|
Loading…
Reference in New Issue
Block a user