mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Generate Test] fix greedy generate test (#8293)
* fix greedy generate test * delet ipdb
This commit is contained in:
parent
734afa37f6
commit
cb966e640b
@ -140,10 +140,6 @@ class GenerationTesterMixin:
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
max_length = 4
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
@ -154,6 +150,13 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_greedy = model.greedy_search(
|
||||
input_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user