From 1e45bef0a733b7115304501565ca885ae11ad32d Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 23 Nov 2020 10:57:27 -0800
Subject: [PATCH] [trainer] make generate work with multigpu (#8716)

* make generate work with multigpu

* better fix - thanks @sgugger
---
 examples/seq2seq/seq2seq_trainer.py       | 2 +-
 examples/seq2seq/test_finetune_trainer.py | 9 +--------
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 99826e2228f..d3af5dce436 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -189,7 +189,7 @@ class Seq2SeqTrainer(Trainer):
         }
 
         if self.args.predict_with_generate and not self.args.prediction_loss_only:
-            generated_tokens = model.generate(
+            generated_tokens = self.model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 **gen_kwargs,
diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index b8c0f4816ce..70cceae3c52 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -4,13 +4,7 @@ from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
-from transformers.testing_utils import (
-    TestCasePlus,
-    execute_subprocess_async,
-    get_gpu_count,
-    require_torch_non_multi_gpu_but_fix_me,
-    slow,
-)
+from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
 
@@ -52,7 +46,6 @@ class TestFinetuneTrainer(TestCasePlus):
         assert "test_results.json" in contents
 
     @slow
-    @require_torch_non_multi_gpu_but_fix_me
     def test_finetune_bert2bert(self):
         if not is_datasets_available():
             return
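
Why the one-line fix works: with more than one GPU the Trainer wraps the model in
torch.nn.DataParallel (or DistributedDataParallel in distributed mode), and the
wrapped `model` handed to `prediction_step` only proxies `forward`; model-specific
methods such as `generate` are not exposed on the wrapper, so `model.generate(...)`
raises AttributeError. `self.model` is the Trainer's reference to the original,
unwrapped model, which still has `generate`. The sketch below is not part of the
patch; `TinyModel` is a hypothetical stand-in for a model carrying a `generate`
method, used only to reproduce the failure mode:

# Minimal sketch of the multi-GPU failure the patch fixes. TinyModel is a
# hypothetical stand-in for a transformers model that mixes in `generate`.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def generate(self, x):
        # stand-in for GenerationMixin.generate
        return self.forward(x).argmax(dim=-1)


model = TinyModel()
wrapped = torch.nn.DataParallel(model)  # what Trainer does when n_gpu > 1

x = torch.randn(2, 4)
wrapped(x)         # fine: DataParallel forwards __call__ to TinyModel.forward
model.generate(x)  # fine: the unwrapped model, i.e. what self.model refers to
try:
    wrapped.generate(x)  # what the old `model.generate(...)` amounted to
except AttributeError as e:
    print(e)  # 'DataParallel' object has no attribute 'generate'

The test-file change follows directly: once generation goes through `self.model`,
`test_finetune_bert2bert` no longer needs to be skipped on multi-GPU machines, so
the `require_torch_non_multi_gpu_but_fix_me` decorator and its import are dropped.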