..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Trainer
-----------------------------------------------------------------------------------------------------------------------

The :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` classes provide an API for feature-complete
training in most standard use cases. It is used in most of the :doc:`example scripts <../examples>`.

Before instantiating your :class:`~transformers.Trainer`/:class:`~transformers.TFTrainer`, create a
:class:`~transformers.TrainingArguments`/:class:`~transformers.TFTrainingArguments` to access all the points of
customization during training.

The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex `__ for PyTorch and
:obj:`tf.keras.mixed_precision` for TensorFlow.

Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop, which supports
the features above. To inject custom behavior, you can subclass them and override the following methods:

- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
  init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **run_model** (TensorFlow only) -- Runs a basic forward pass through the model.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.

Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:

.. code-block:: python

    from transformers import Trainer

    class MyTrainer(Trainer):
        def compute_loss(self, model, inputs):
            # Pop the labels so the model does not compute its default loss
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            # The first element of the model outputs holds the logits
            logits = outputs[0]
            # my_custom_loss is your own function that returns a scalar loss tensor
            return my_custom_loss(logits, labels)

Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks ` that can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML
platforms, etc.) and make decisions (such as early stopping).

Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Trainer
    :members:


Seq2SeqTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Seq2SeqTrainer
    :members: evaluate, predict


TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainer
    :members:


TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
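As described at the top of this page, the training options are collected in a
:class:`~transformers.TrainingArguments` instance that is then handed to :class:`~transformers.Trainer`. The snippet
below is a minimal sketch of that pattern, assuming a PyTorch installation. The output directory, the hyperparameter
values and the tiny randomly initialized BERT configuration are illustrative choices made here (not library defaults),
and the datasets you would normally pass as ``train_dataset``/``eval_dataset`` are omitted, so the sketch only builds
the trainer and does not call :meth:`~transformers.Trainer.train`.

.. code-block:: python

    from transformers import (
        BertConfig,
        BertForSequenceClassification,
        Trainer,
        TrainingArguments,
    )

    # All training options live in TrainingArguments; output_dir is where
    # checkpoints are written (the directory name here is arbitrary).
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        learning_rate=5e-5,
    )

    # A tiny, randomly initialized model so nothing has to be downloaded;
    # in practice you would load a pretrained checkpoint with from_pretrained().
    config = BertConfig(
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=37,
        num_labels=2,
    )
    model = BertForSequenceClassification(config)

    # train_dataset/eval_dataset would normally be passed here as well.
    trainer = Trainer(model=model, args=training_args)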
.. autoclass:: transformers.TrainingArguments
    :members:


Seq2SeqTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Seq2SeqTrainingArguments
    :members:


TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainingArguments
    :members: