mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[trainer] make generate work with multigpu (#8716)
* make generate work with multigpu * better fix - thanks @sgugger
This commit is contained in:
parent
900024273b
commit
1e45bef0a7
@ -189,7 +189,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||||
generated_tokens = model.generate(
|
generated_tokens = self.model.generate(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
**gen_kwargs,
|
**gen_kwargs,
|
||||||
|
@ -4,13 +4,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel
|
from transformers import BertTokenizer, EncoderDecoderModel
|
||||||
from transformers.file_utils import is_datasets_available
|
from transformers.file_utils import is_datasets_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
|
||||||
TestCasePlus,
|
|
||||||
execute_subprocess_async,
|
|
||||||
get_gpu_count,
|
|
||||||
require_torch_non_multi_gpu_but_fix_me,
|
|
||||||
slow,
|
|
||||||
)
|
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
@ -52,7 +46,6 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
assert "test_results.json" in contents
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_non_multi_gpu_but_fix_me
|
|
||||||
def test_finetune_bert2bert(self):
|
def test_finetune_bert2bert(self):
|
||||||
if not is_datasets_available():
|
if not is_datasets_available():
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user