Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add tokenizer to Trainer (#6689)
commit 124c3d6adc (parent abc0202194)
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
-from .data.data_collator import DataCollator, default_data_collator
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .file_utils import is_nlp_available, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
@@ -31,6 +31,7 @@ from .integrations import (
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
+from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
@@ -168,15 +169,20 @@ class Trainer:
         args (:class:`~transformers.TrainingArguments`, `optional`):
             The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
             with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
-        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
+        data_collator (:obj:`DataCollator`, `optional`):
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
-            :obj:`eval_dataset`.
+            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
+            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
         train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
+            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
+            maximum length when batching inputs, and it will be saved along with the model to make it easier to rerun
+            an interrupted training or reuse the fine-tuned model.
         model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
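
The docstring above describes the new behavior: when a tokenizer is passed, the Trainer switches its default collator and keeps the tokenizer with every checkpoint. A minimal sketch of how a caller might use the new argument (the checkpoint name, output directory, and the pre-tokenized train_dataset/eval_dataset variables are illustrative assumptions, not part of this commit):

    # Sketch only: "bert-base-uncased", "./out" and the two dataset variables are assumptions.
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="./out"),
        train_dataset=train_dataset,  # assumed: tokenized but unpadded examples with labels
        eval_dataset=eval_dataset,    # assumed: same preprocessing as train_dataset
        tokenizer=tokenizer,          # new argument introduced by this commit
    )
    # No data_collator is passed, so the Trainer falls back to DataCollatorWithPadding(tokenizer).
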
@@ -200,6 +206,7 @@ class Trainer:
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         tb_writer: Optional["SummaryWriter"] = None,
@@ -218,9 +225,11 @@ class Trainer:
         if model is None and model_init is not None:
             model = model_init()
         self.model = model.to(args.device) if model is not None else None
-        self.data_collator = data_collator if data_collator is not None else default_data_collator
+        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
+        self.tokenizer = tokenizer
         self.model_init = model_init
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
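
The constructor change above is the core of the commit: with no explicit data_collator, the Trainer now builds a DataCollatorWithPadding around the tokenizer instead of using default_data_collator. A small sketch of what that collator does when batching unpadded encodings (the checkpoint name and sentences are arbitrary examples):

    # Sketch only: checkpoint and sentences are arbitrary; this mirrors the default chosen above.
    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collator = DataCollatorWithPadding(tokenizer)  # what the Trainer now creates by default

    features = [
        tokenizer("a short sentence"),
        tokenizer("a noticeably longer sentence that needs many more tokens"),
    ]
    batch = collator(features)
    print(batch["input_ids"].shape)    # both examples padded to the length of the longest one
    print(batch["attention_mask"][0])  # trailing zeros mark the padded positions
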
@@ -1091,6 +1100,8 @@ class Trainer:
 
         xm.rendezvous("saving_checkpoint")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
     def _save(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1101,6 +1112,8 @@ class Trainer:
         if not isinstance(self.model, PreTrainedModel):
             raise ValueError("Trainer.model appears to not be a PreTrainedModel")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
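
Both save paths now write the tokenizer next to the model weights. Assuming the illustrative "./out" output directory from the sketch above, resuming or reusing the fine-tuned checkpoint no longer requires keeping track of the original tokenizer separately:

    # Sketch only: "./out" must already contain a checkpoint saved by the Trainer.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained("./out")
    tokenizer = AutoTokenizer.from_pretrained("./out")  # present because the Trainer now saves it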