[trainer] make generate work with multigpu (#8716)

* make generate work with multigpu

* better fix - thanks @sgugger
This commit is contained in:
Stas Bekman 2020-11-23 10:57:27 -08:00 committed by GitHub
parent 900024273b
commit 1e45bef0a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 9 deletions

View File

@ -189,7 +189,7 @@ class Seq2SeqTrainer(Trainer):
}
if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,

View File

@ -4,13 +4,7 @@ from unittest.mock import patch
from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
require_torch_non_multi_gpu_but_fix_me,
slow,
)
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
@ -52,7 +46,6 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_results.json" in contents
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_finetune_bert2bert(self):
if not is_datasets_available():
return