Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Flaky Generation Tests] Make sure that no early stopping is happening for beam search (#9794)
* fix ci
* fix ci
* renaming
* fix dup line
Parent: 0fdbf0850a
Commit: d94cc2f904
@@ -625,6 +625,12 @@ class GenerationTesterMixin:
     def test_beam_search_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
 
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
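
The comment added above is the heart of the fix: with an `eos_token_id` set, individual beams can finish before `max_length`, and when the top `num_return_sequences` beams all end up shorter than the longest beam, the generated length can vary from run to run, which is what made CI flaky. A minimal sketch of the behaviour outside the test suite, assuming a small public checkpoint (`sshleifer/tiny-gpt2` is only an illustration, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed tiny checkpoint, used only to illustrate the behaviour described in the comment.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Mirror the test setup: with no eos_token_id, no beam can stop before max_length.
model.config.eos_token_id = None

with torch.no_grad():
    output = model.generate(
        input_ids,
        num_beams=3,
        num_return_sequences=2,
        max_length=10,
    )

# Every returned beam ran to exactly max_length, so the output shape is deterministic.
assert output.shape == (2, 10)
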
@@ -669,9 +675,16 @@
 
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # disable cache
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                 input_ids.shape[-1], config.eos_token_id
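
The `_dict_output` variants ask generation for a structured output object rather than a bare tensor of ids. A hedged sketch of what that looks like through the public `generate()` flags, again with an assumed tiny checkpoint; the exact fields the test inspects are not shown in this hunk:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()
model.config.use_cache = False      # the test disables the cache
model.config.eos_token_id = None    # and disables stopping on EOS

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    num_beams=2,
    max_length=8,
    return_dict_in_generate=True,
    output_scores=True,
)

print(output.sequences.shape)   # (1, 8): full max_length, since no beam could stop early
print(output.sequences_scores)  # final log-probability of each returned beam
print(len(output.scores))       # one entry of processed scores per generated step
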
@@ -715,11 +728,15 @@
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             if not hasattr(config, "use_cache"):
                 # only relevant if model has "use_cache"
                 return
 
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             model = model_class(config).to(torch_device).eval()
 
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
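
This hunk also drops a duplicated `_get_input_ids_and_config()` call (presumably the "fix dup line" mentioned in the commit message) and keeps the `use_cache` guard. As a rough, hedged illustration of why the guard and the cache toggle matter, and not the test's actual assertions, caching is meant to be a pure speed optimization, so beam search output should not depend on it (assumed tiny checkpoint):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()
model.config.eos_token_id = None

# Same guard as the test: the check only makes sense for models that expose use_cache.
if not hasattr(model.config, "use_cache"):
    raise SystemExit("only relevant if the model has `use_cache`")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
with_cache = model.generate(input_ids, num_beams=2, max_length=8, use_cache=True)
without_cache = model.generate(input_ids, num_beams=2, max_length=8, use_cache=False)

# Caching only speeds up decoding, so the generated ids should be identical.
assert torch.equal(with_cache, without_cache)
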
@@ -758,7 +775,12 @@
     def test_beam_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            print("Return dict", config.return_dict)
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
 
             model = model_class(config).to(torch_device).eval()
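
`test_beam_sample_generate` exercises beam search combined with sampling. A small sketch of the same generation mode through the public API, with an assumed checkpoint and illustrative parameter values; the seed only makes the sketch repeatable and is not taken from the test:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()
model.config.eos_token_id = None  # again: no beam may stop before max_length

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

torch.manual_seed(0)  # sampling is stochastic; seeding keeps this sketch repeatable
output = model.generate(
    input_ids,
    do_sample=True,   # beam sample = beam search with sampled, not argmax, continuations
    num_beams=2,
    top_k=10,
    max_length=10,
)
assert output.shape[-1] == 10  # every beam ran to max_length
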
@@ -788,9 +810,16 @@
 
     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # disable cache
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
 
@@ -859,6 +888,11 @@
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                 input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
             )
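
The `diversity_penalty=2.0` argument points at group ("diverse") beam search. A hedged sketch of that mode through the public `generate()` API, with an assumed checkpoint and illustrative parameter values:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()
model.config.eos_token_id = None

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    num_beams=4,
    num_beam_groups=2,      # beams are split into groups decoded one after another
    diversity_penalty=2.0,  # discourages a group from re-using tokens picked by earlier groups
    max_length=10,
)
assert output.shape == (1, 10)  # default num_return_sequences=1, run to full max_length
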
@@ -904,6 +938,12 @@
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated, which could lead to flaky CircleCI
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
 
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(