[Generate Test] fix greedy generate test (#8293)

* fix greedy generate test

* delet ipdb
This commit is contained in:
Patrick von Platen 2020-11-04 15:44:36 +01:00 committed by GitHub
parent 734afa37f6
commit cb966e640b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -140,10 +140,6 @@ class GenerationTesterMixin:
# check `generate()` and `greedy_search()` are equal
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
max_length = 4
output_ids_generate = model.generate(
@ -154,6 +150,13 @@ class GenerationTesterMixin:
max_length=max_length,
**logits_process_kwargs,
)
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
with torch.no_grad():
output_ids_greedy = model.greedy_search(
input_ids,