Trainer
-----------------------------------------------------------------------------------------------------------------------

The :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` classes provide an API for feature-complete
training in most standard use cases. This API is used in most of the :doc:`example scripts <../examples>`.

Before instantiating your :class:`~transformers.Trainer`/:class:`~transformers.TFTrainer`, create a
:class:`~transformers.TrainingArguments`/:class:`~transformers.TFTrainingArguments` to access all the points of
customization during training.

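
Following the order described above, here is a minimal sketch (assuming PyTorch is installed) that builds a
:class:`~transformers.TrainingArguments` and passes it to a :class:`~transformers.Trainer`. The tiny, randomly
initialized BERT model only keeps the example self-contained (nothing is downloaded), the hyperparameter values are
arbitrary, and no datasets are given, so this trainer is not yet ready for ``train()``.

.. code-block:: python

    from transformers import BertConfig, BertForSequenceClassification, Trainer, TrainingArguments

    # The arguments object gathers the customization points (epochs, batch size, learning rate, ...)
    training_args = TrainingArguments(
        output_dir="./results",  # placeholder directory for checkpoints
        num_train_epochs=2,
        per_device_train_batch_size=8,
        learning_rate=5e-5,
    )

    # Arbitrary tiny configuration so the model is built locally from random weights
    config = BertConfig(
        vocab_size=100,
        hidden_size=16,
        num_hidden_layers=1,
        num_attention_heads=2,
        intermediate_size=32,
        num_labels=2,
    )
    model = BertForSequenceClassification(config)

    # train_dataset/eval_dataset would normally be passed here as well
    trainer = Trainer(model=model, args=training_args)
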
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.

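
On the PyTorch side, mixed precision is requested through the ``fp16`` field of
:class:`~transformers.TrainingArguments`. The sketch below assumes PyTorch and only enables the flag when a CUDA GPU
is visible; the output directory is a placeholder.

.. code-block:: python

    import torch
    from transformers import TrainingArguments

    # Half-precision training targets GPUs, so fall back to full precision elsewhere
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir="./results",  # placeholder directory for checkpoints
        fp16=use_fp16,
    )

When Apex is the backend in use, ``fp16_opt_level`` (default ``"O1"``) selects its optimization level.
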
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop that supports
the features above. To inject custom behavior, you can subclass them and override the following methods:

- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
  init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **run_model** (TensorFlow only) -- Runs a basic forward pass through the model.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.

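
As an illustration of overriding one of these hooks, here is a hypothetical subclass whose ``get_train_dataloader``
serves training examples in dataset order instead of shuffling them. It is only a sketch: it replaces the base
implementation wholesale, so anything that method normally handles (for example distributed sampling or dropping
unused columns) is not reproduced. It assumes PyTorch and relies on the ``train_dataset``, ``data_collator`` and
``args`` attributes that :class:`~transformers.Trainer` stores at init.

.. code-block:: python

    from torch.utils.data import DataLoader, SequentialSampler
    from transformers import Trainer


    class InOrderTrainer(Trainer):
        """Hypothetical Trainer subclass that never shuffles the training set."""

        def get_train_dataloader(self):
            # SequentialSampler yields indices 0, 1, 2, ... so batches follow dataset order
            return DataLoader(
                self.train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=SequentialSampler(self.train_dataset),
                collate_fn=self.data_collator,
            )
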
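Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:

.. code-block:: python

    import torch.nn.functional as F
    from transformers import Trainer

    def my_custom_loss(logits, labels):
        # Placeholder for your own loss; plain cross-entropy is used here for illustration
        return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    class MyTrainer(Trainer):
        def compute_loss(self, model, inputs):
            labels = inputs.pop("labels")  # pop the labels so the model does not compute its own loss
            outputs = model(**inputs)
            logits = outputs[0]  # without labels, the first output is the logits
            return my_custom_loss(logits, labels)
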
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
other ML platforms...) and make decisions (like early stopping).


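
As a rough sketch of such a callback (the class name and its ``history`` attribute are made up for illustration), the
subclass of :class:`~transformers.TrainerCallback` below overrides the ``on_log`` event and keeps every training loss
the :class:`~transformers.Trainer` logs, together with the global step at which it was logged.

.. code-block:: python

    from transformers import TrainerCallback


    class LossHistoryCallback(TrainerCallback):
        """Hypothetical callback that records (global_step, loss) pairs."""

        def __init__(self):
            self.history = []

        def on_log(self, args, state, control, logs=None, **kwargs):
            # `logs` holds what the Trainer just logged; evaluation logs use other keys such as "eval_loss"
            if logs is not None and "loss" in logs:
                self.history.append((state.global_step, logs["loss"]))

It would be registered with ``Trainer(..., callbacks=[LossHistoryCallback()])``; the :doc:`callbacks <callback>` page
lists the other events and the state and control objects passed to each one.

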
Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Trainer
    :members:


TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainer
    :members:


TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainingArguments
    :members:


TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainingArguments
    :members: