mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Gemma 3 tests expect greedy decoding (#36882)
tests expect greedy decoding
This commit is contained in:
parent
b8aadc31d5
commit
2638d54e78
@ -567,7 +567,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
@ -599,6 +599,11 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
# Generation works beyond sliding window
|
||||
self.assertGreater(out.shape[1], model.config.sliding_window)
|
||||
self.assertEqual(out.shape[1], input_size + 5)
|
||||
|
Loading…
Reference in New Issue
Block a user