Mirror of https://github.com/huggingface/transformers.git
Add tokenizer to Trainer (#6689)
commit 124c3d6adc (parent abc0202194)
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
-from .data.data_collator import DataCollator, default_data_collator
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .file_utils import is_nlp_available, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
@@ -31,6 +31,7 @@ from .integrations import (
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
+from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
     BestRun,
@@ -168,15 +169,20 @@ class Trainer:
         args (:class:`~transformers.TrainingArguments`, `optional`):
             The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
             with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
-        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
+        data_collator (:obj:`DataCollator`, `optional`):
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
-            :obj:`eval_dataset`.
+            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
+            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
         train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
+            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
+            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
+            interrupted training or reuse the fine-tuned model.
         model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
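As a rough illustration of the behaviour documented above, the sketch below constructs a Trainer with a tokenizer and no explicit data_collator. It is a minimal sketch, not part of this commit: the checkpoint name is an arbitrary placeholder and the datasets are omitted for brevity.

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Placeholder checkpoint; any compatible model/tokenizer pair behaves the same way.
checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="tmp_trainer"),
    tokenizer=tokenizer,
    # train_dataset / eval_dataset omitted here; any tokenized torch Dataset can be passed.
)
# With no explicit data_collator, batches are padded dynamically via
# DataCollatorWithPadding(tokenizer) instead of default_data_collator.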
@@ -200,6 +206,7 @@ class Trainer:
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         tb_writer: Optional["SummaryWriter"] = None,
@@ -218,9 +225,11 @@ class Trainer:
         if model is None and model_init is not None:
             model = model_init()
         self.model = model.to(args.device) if model is not None else None
-        self.data_collator = data_collator if data_collator is not None else default_data_collator
+        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
+        self.tokenizer = tokenizer
         self.model_init = model_init
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
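The default-collator selection above means that, when a tokenizer is passed, batching goes through DataCollatorWithPadding. A small standalone sketch of what that collator does (placeholder checkpoint and made-up sentences, not code from this commit):

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder checkpoint
collator = DataCollatorWithPadding(tokenizer)

# Two examples tokenized without padding, as a preprocessing step typically leaves them.
features = [
    tokenizer("a short sentence"),
    tokenizer("a noticeably longer example sentence for the same batch"),
]
batch = collator(features)

# Both tensors now share the length of the longest example in the batch.
print(batch["input_ids"].shape, batch["attention_mask"].shape)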
@@ -1091,6 +1100,8 @@ class Trainer:
 
         xm.rendezvous("saving_checkpoint")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
     def _save(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1101,6 +1112,8 @@ class Trainer:
         if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
         self.model.save_pretrained(output_dir)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
 
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
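Because the tokenizer is now written out next to the model weights, a saved output directory can be reloaded on its own. A small sketch, assuming a directory previously written by a Trainer configured with a tokenizer (the path is a placeholder):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed to have been written by a prior Trainer save; it now contains the
# tokenizer files alongside the model weights and config.
output_dir = "tmp_trainer"
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForSequenceClassification.from_pretrained(output_dir)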