Add tokenizer to Trainer (#6689)

Sylvain Gugger, 2020-08-25 07:47:09 -04:00, committed by GitHub
parent abc0202194
commit 124c3d6adc

@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
-from .data.data_collator import DataCollator, default_data_collator
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .file_utils import is_nlp_available, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
@@ -31,6 +31,7 @@ from .integrations import (
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
+from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
@@ -168,15 +169,20 @@ class Trainer:
         args (:class:`~transformers.TrainingArguments`, `optional`):
             The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
             with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
-        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
+        data_collator (:obj:`DataCollator`, `optional`):
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
-            :obj:`eval_dataset`.
+            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
+            provided, and to an instance of :class:`~transformers.DataCollatorWithPadding` otherwise.
         train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
+            The tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs to
+            the maximum length when batching them, and it will be saved along with the model to make it easier to
+            rerun an interrupted training or reuse the fine-tuned model.
         model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
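For context, a minimal sketch of how the new ``tokenizer`` argument is used (the checkpoint name and the toy dataset below are illustrative assumptions, not part of this commit). Passing a tokenizer and no ``data_collator`` makes the Trainer fall back to ``DataCollatorWithPadding(tokenizer)`` and also saves the tokenizer alongside the model:

    # Illustrative only: "bert-base-uncased" and ToyDataset are placeholders.
    from torch.utils.data import Dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    class ToyDataset(Dataset):
        # Tokenized examples of uneven length; padding is left to the data collator.
        def __init__(self, texts, labels):
            self.encodings = [tokenizer(t) for t in texts]
            self.labels = labels

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, i):
            item = dict(self.encodings[i])   # plain python lists, no padding yet
            item["labels"] = self.labels[i]
            return item

    train_dataset = ToyDataset(["a short example", "a somewhat longer training example"], [0, 1])

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="tmp_trainer"),
        train_dataset=train_dataset,
        tokenizer=tokenizer,  # no data_collator passed -> DataCollatorWithPadding(tokenizer) is used
    )
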
@@ -200,6 +206,7 @@ class Trainer:
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         tb_writer: Optional["SummaryWriter"] = None,
@@ -218,9 +225,11 @@ class Trainer:
         if model is None and model_init is not None:
             model = model_init()
         self.model = model.to(args.device) if model is not None else None
-        self.data_collator = data_collator if data_collator is not None else default_data_collator
+        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
+        self.tokenizer = tokenizer
         self.model_init = model_init
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
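The default-collator logic above can also be exercised directly; a short sketch (the checkpoint name is an assumption) showing that features of uneven length are padded to the longest sequence in the batch:

    # Mirrors the new default in Trainer.__init__ (checkpoint name is illustrative).
    from transformers import AutoTokenizer, DataCollatorWithPadding, default_data_collator

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    data_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)

    # Two encodings of different lengths; the collator pads them to a common length.
    features = [tokenizer("a short sentence"), tokenizer("a noticeably longer example sentence")]
    batch = data_collator(features)
    print(batch["input_ids"].shape)    # (2, length of the longest sequence in the batch)
    print(batch["attention_mask"][0])  # trailing zeros mark the padding positions

Compared with the previous default, which expected inputs of equal length, this lets a dataset be tokenized without padding and padded lazily per batch.
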
@@ -1091,6 +1100,8 @@ class Trainer:
         xm.rendezvous("saving_checkpoint")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
     def _save(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1101,6 +1112,8 @@ class Trainer:
         if not isinstance(self.model, PreTrainedModel):
             raise ValueError("Trainer.model appears to not be a PreTrainedModel")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
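
Since both save paths now write the tokenizer next to the model, the output directory becomes a self-contained checkpoint; a small sketch of reloading it (the directory name is whatever ``output_dir`` was used):

    # Assumes training already wrote model + tokenizer files into this directory.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    output_dir = "tmp_trainer"  # placeholder: the Trainer's output_dir / checkpoint folder
    tokenizer = AutoTokenizer.from_pretrained(output_dir)  # files written by tokenizer.save_pretrained
    model = AutoModelForSequenceClassification.from_pretrained(output_dir)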