Gemma 3 tests expect greedy decoding (#36882)

tests expect greedy decoding
This commit is contained in:
Pablo Montalvo 2025-03-21 12:36:39 +01:00 committed by GitHub
parent b8aadc31d5
commit 2638d54e78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -567,7 +567,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
@ -599,6 +599,11 @@ class Gemma3IntegrationTest(unittest.TestCase):
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
out = model.generate(**inputs, generation_config=generation_config)
out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
# Generation works beyond sliding window
self.assertGreater(out.shape[1], model.config.sliding_window)
self.assertEqual(out.shape[1], input_size + 5)