mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-14 18:18:24 +06:00
Gemma 3 tests expect greedy decoding (#36882)
tests expect greedy decoding
This commit is contained in:
parent
b8aadc31d5
commit
2638d54e78
@ -567,7 +567,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
input_size = inputs.input_ids.shape[-1]
|
input_size = inputs.input_ids.shape[-1]
|
||||||
self.assertTrue(input_size > model.config.sliding_window)
|
self.assertTrue(input_size > model.config.sliding_window)
|
||||||
|
|
||||||
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
|
||||||
output_text = tokenizer.batch_decode(out)
|
output_text = tokenizer.batch_decode(out)
|
||||||
|
|
||||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||||
@ -599,6 +599,11 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
||||||
out = model.generate(**inputs, generation_config=generation_config)
|
out = model.generate(**inputs, generation_config=generation_config)
|
||||||
|
|
||||||
|
out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:]
|
||||||
|
output_text = tokenizer.batch_decode(out)
|
||||||
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||||
|
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||||
|
|
||||||
# Generation works beyond sliding window
|
# Generation works beyond sliding window
|
||||||
self.assertGreater(out.shape[1], model.config.sliding_window)
|
self.assertGreater(out.shape[1], model.config.sliding_window)
|
||||||
self.assertEqual(out.shape[1], input_size + 5)
|
self.assertEqual(out.shape[1], input_size + 5)
|
||||||
|
Loading…
Reference in New Issue
Block a user