[Tests] add min_new_tokens to prevent flaky length checks (#37175)

commit e90d55ebcc
parent cbfa14823b
Author: Joao Gante
Date:   2025-04-02 15:24:00 +01:00

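Why this removes the flakiness: `generate` may stop early when the model emits an EOS token, so outputs can come back shorter than `max_new_tokens` and exact-length assertions fail intermittently. Setting `min_new_tokens` equal to `max_new_tokens` suppresses EOS until exactly that many new tokens have been produced, making the output length deterministic. A minimal sketch of the resulting invariant, using an illustrative checkpoint and prompt rather than anything from the test suite itself:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the mixin below runs against each model's own tiny test config.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
max_new_tokens = 10

output = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=max_new_tokens,
    min_new_tokens=max_new_tokens,  # EOS is suppressed until 10 new tokens exist
)

# With min_new_tokens == max_new_tokens the output length never varies,
# so a strict equality check cannot flake.
assert output.shape[1] == inputs["input_ids"].shape[1] + max_new_tokens
```

The hunks below apply exactly this pinning to every `model.generate(...)` call whose output length the tests check.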

@@ -281,6 +281,7 @@ class GenerationTesterMixin:
             do_sample=False,
             num_beams=1,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
@@ -311,6 +312,7 @@ class GenerationTesterMixin:
             do_sample=True,
             num_beams=1,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             num_return_sequences=num_return_sequences,
             output_scores=output_scores,
             output_logits=output_logits,
@@ -340,6 +342,7 @@ class GenerationTesterMixin:
         output_generate = model.generate(
             do_sample=False,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_scores=output_scores,
             output_logits=output_logits,
             output_attentions=output_attentions,
@@ -370,6 +373,7 @@ class GenerationTesterMixin:
         output_generate = model.generate(
             do_sample=True,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_scores=output_scores,
             output_logits=output_logits,
             output_attentions=output_attentions,
@@ -399,6 +403,7 @@ class GenerationTesterMixin:
         output_generate = model.generate(
             do_sample=False,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_scores=output_scores,
             output_logits=output_logits,
             output_attentions=output_attentions,
@@ -429,6 +434,7 @@ class GenerationTesterMixin:
         output_generate = model.generate(
             do_sample=False,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_scores=output_scores,
             output_logits=output_logits,
             output_attentions=output_attentions,
@@ -464,6 +470,7 @@ class GenerationTesterMixin:
             do_sample=False,
             num_beams=1,
             max_new_tokens=self.max_new_tokens,
+            min_new_tokens=self.max_new_tokens,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,