transformers/docs/source/main_classes/trainer.mdx
Sylvain Gugger 459677aebe
PoC for conserving old links (#14754)
* PoC for conserving old links

* Do the same for other links

* remap the redirects section

* add instructions on how to move sections

* improve

Co-authored-by: Stas Bekman <stas@stason.org>
2021-12-15 11:40:47 -08:00

471 lines
20 KiB
Plaintext

<!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Trainer
The [`Trainer`] class provides an API for feature-complete training in PyTorch for most standard use cases. It's used in most of the [example scripts](../examples).
Before instantiating your [`Trainer`], create a [`TrainingArguments`] to access all the points of customization during training.
The API supports distributed training on multiple GPUs/TPUs, mixed precision through [NVIDIA Apex](https://github.com/NVIDIA/apex) and Native AMP for PyTorch.
The [`Trainer`] contains the basic training loop which supports the above features. To inject custom behavior you can subclass them and override the following methods:
- **get_train_dataloader** -- Creates the training DataLoader.
- **get_eval_dataloader** -- Creates the evaluation DataLoader.
- **get_test_dataloader** -- Creates the test DataLoader.
- **log** -- Logs information on the various objects watching training.
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
init. Note, that you can also subclass or override the `create_optimizer` and `create_scheduler` methods
separately.
- **create_optimizer** -- Sets up the optimizer if it wasn't passed at init.
- **create_scheduler** -- Sets up the learning rate scheduler if it wasn't passed at init.
- **compute_loss** - Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
<Tip warning={true}>
The [`Trainer`] class is optimized for 🤗 Transformers models and can have surprising behaviors
when you use it on other models. When using it on your own model, make sure:
- your model always return tuples or subclasses of [`~file_utils.ModelOutput`].
- your model can compute the loss if a `labels` argument is provided and that loss is returned as the first
element of the tuple (if your model returns tuples)
- your model can accept multiple label arguments (use the `label_names` in your [`TrainingArguments`] to indicate their name to the [`Trainer`]) but none of them should be named `"label"`.
</Tip>
Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
```python
from torch import nn
from transformers import Trainer
class MultilabelTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.get("labels")
outputs = model(**inputs)
logits = outputs.get('logits')
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
labels.float().view(-1, self.model.config.num_labels))
return (loss, outputs) if return_outputs else loss
```
Another way to customize the training loop behavior for the PyTorch [`Trainer`] is to use [callbacks](callback) that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early stopping).
## Trainer
[[autodoc]] Trainer
- all
## Seq2SeqTrainer
[[autodoc]] Seq2SeqTrainer
- evaluate
- predict
## TrainingArguments
[[autodoc]] TrainingArguments
- all
## Seq2SeqTrainingArguments
[[autodoc]] Seq2SeqTrainingArguments
- all
## Checkpoints
By default, [`Trainer`] will save all checkpoints in the `output_dir` you set in the
[`TrainingArguments`] you are using. Those will go in subfolder named `checkpoint-xxx` with xxx
being the step at which the training was at.
Resuming training from a checkpoint can be done when calling [`Trainer.train`] with either:
- `resume_from_checkpoint=True` which will resume training from the latest checkpoint
- `resume_from_checkpoint=checkpoint_dir` which will resume training from the specific checkpoint in the directory
passed.
In addition, you can easily save your checkpoints on the Model Hub when using `push_to_hub=True`. By default, all
the models saved in intermediate checkpoints are saved in different commits, but not the optimizer state. You can adapt
the `hub-strategy` value of your [`TrainingArguments`] to either:
- `"checkpoint"`: the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to
resume training easily with `trainer.train(resume_from_checkpoint="output_dir/last-checkpoint")`.
- `"all_checkpoints"`: all checkpoints are pushed like they appear in the output folder (so you will get one
checkpoint folder per folder in your final repository)
## Logging
By default [`Trainer`] will use `logging.INFO` for the main process and `logging.WARNING` for the replicas if any.
These defaults can be overridden to use any of the 5 `logging` levels with [`TrainingArguments`]'s
arguments:
- `log_level` - for the main process
- `log_level_replica` - for the replicas
Further, if [`TrainingArguments`]'s `log_on_each_node` is set to `False` only the main node will
use the log level settings for its main process, all other nodes will use the log level settings for replicas.
Note that [`Trainer`] is going to set `transformers`'s log level separately for each node in its
[`Trainer.__init__`]. So you may want to set this sooner (see the next example) if you tap into other
`transformers` functionality before creating the [`Trainer`] object.
Here is an example of how this can be used in an application:
```python
[...]
logger = logging.getLogger(__name__)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
# set the main code and the modules it uses to the same log-level according to the node
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)
```
And then if you only want to see warnings on the main node and all other nodes to not print any most likely duplicated
warnings you could run it as:
```bash
my_app.py ... --log_level warning --log_level_replica error
```
In the multi-node environment if you also don't want the logs to repeat for each node's main process, you will want to
change the above to:
```bash
my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0
```
and then only the main process of the first node will log at the "warning" level, and all other processes on the main
node and all processes on other nodes will log at the "error" level.
If you need your application to be as quiet as possible you could do:
```bash
my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0
```
(add `--log_on_each_node 0` if on multi-node environment)
## Randomness
When resuming from a checkpoint generated by [`Trainer`] all efforts are made to restore the
_python_, _numpy_ and _pytorch_ RNG states to the same states as they were at the moment of saving that checkpoint,
which should make the "stop and resume" style of training as close as possible to non-stop training.
However, due to various default non-deterministic pytorch settings this might not fully work. If you want full
determinism please refer to [Controlling sources of randomness](https://pytorch.org/docs/stable/notes/randomness). As explained in the document, that some of those settings
that make things deterministic (.e.g., `torch.backends.cudnn.deterministic`) may slow things down, therefore this
can't be done by default, but you can enable those yourself if needed.
## Trainer Integrations
The [`Trainer`] has been extended to support libraries that may dramatically improve your training
time and fit much bigger models.
Currently it supports third party solutions, [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [FairScale](https://github.com/facebookresearch/fairscale/), which implement parts of the paper [ZeRO: Memory Optimizations
Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He](https://arxiv.org/abs/1910.02054).
This provided support is new and experimental as of this writing.
<a id='zero-install-notes'></a>
### CUDA Extension Installation Notes
As of this writing, both FairScale and Deepspeed require compilation of CUDA C++ code, before they can be used.
While all installation issues should be dealt with through the corresponding GitHub Issues of [FairScale](https://github.com/facebookresearch/fairscale/issues) and [Deepspeed](https://github.com/microsoft/DeepSpeed/issues), there are a few common issues that one may encounter while building
any PyTorch extension that needs to build CUDA extensions.
Therefore, if you encounter a CUDA-related build issue while doing one of the following or both:
```bash
pip install fairscale
pip install deepspeed
```
please, read the following notes first.
In these notes we give examples for what to do when `pytorch` has been built with CUDA `10.2`. If your situation is
different remember to adjust the version number to the one you are after.
#### Possible problem #1
While, Pytorch comes with its own CUDA toolkit, to build these two projects you must have an identical version of CUDA
installed system-wide.
For example, if you installed `pytorch` with `cudatoolkit==10.2` in the Python environment, you also need to have
CUDA `10.2` installed system-wide.
The exact location may vary from system to system, but `/usr/local/cuda-10.2` is the most common location on many
Unix systems. When CUDA is correctly set up and added to the `PATH` environment variable, one can find the
installation location by doing:
```bash
which nvcc
```
If you don't have CUDA installed system-wide, install it first. You will find the instructions by using your favorite
search engine. For example, if you're on Ubuntu you may want to search for: [ubuntu cuda 10.2 install](https://www.google.com/search?q=ubuntu+cuda+10.2+install).
#### Possible problem #2
Another possible common problem is that you may have more than one CUDA toolkit installed system-wide. For example you
may have:
```bash
/usr/local/cuda-10.2
/usr/local/cuda-11.0
```
Now, in this situation you need to make sure that your `PATH` and `LD_LIBRARY_PATH` environment variables contain
the correct paths to the desired CUDA version. Typically, package installers will set these to contain whatever the
last version was installed. If you encounter the problem, where the package build fails because it can't find the right
CUDA version despite you having it installed system-wide, it means that you need to adjust the 2 aforementioned
environment variables.
First, you may look at their contents:
```bash
echo $PATH
echo $LD_LIBRARY_PATH
```
so you get an idea of what is inside.
It's possible that `LD_LIBRARY_PATH` is empty.
`PATH` lists the locations of where executables can be found and `LD_LIBRARY_PATH` is for where shared libraries
are to looked for. In both cases, earlier entries have priority over the later ones. `:` is used to separate multiple
entries.
Now, to tell the build program where to find the specific CUDA toolkit, insert the desired paths to be listed first by
doing:
```bash
export PATH=/usr/local/cuda-10.2/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH
```
Note that we aren't overwriting the existing values, but prepending instead.
Of course, adjust the version number, the full path if need be. Check that the directories you assign actually do
exist. `lib64` sub-directory is where the various CUDA `.so` objects, like `libcudart.so` reside, it's unlikely
that your system will have it named differently, but if it is adjust it to reflect your reality.
#### Possible problem #3
Some older CUDA versions may refuse to build with newer compilers. For example, you my have `gcc-9` but it wants
`gcc-7`.
There are various ways to go about it.
If you can install the latest CUDA toolkit it typically should support the newer compiler.
Alternatively, you could install the lower version of the compiler in addition to the one you already have, or you may
already have it but it's not the default one, so the build system can't see it. If you have `gcc-7` installed but the
build system complains it can't find it, the following might do the trick:
```bash
sudo ln -s /usr/bin/gcc-7 /usr/local/cuda-10.2/bin/gcc
sudo ln -s /usr/bin/g++-7 /usr/local/cuda-10.2/bin/g++
```
Here, we are making a symlink to `gcc-7` from `/usr/local/cuda-10.2/bin/gcc` and since
`/usr/local/cuda-10.2/bin/` should be in the `PATH` environment variable (see the previous problem's solution), it
should find `gcc-7` (and `g++7`) and then the build will succeed.
As always make sure to edit the paths in the example to match your situation.
### FairScale
By integrating [FairScale](https://github.com/facebookresearch/fairscale/) the [`Trainer`]
provides support for the following features from [the ZeRO paper](https://arxiv.org/abs/1910.02054):
1. Optimizer State Sharding
2. Gradient Sharding
3. Model Parameters Sharding (new and very experimental)
4. CPU offload (new and very experimental)
You will need at least two GPUs to use this feature.
**Installation**:
Install the library via pypi:
```bash
pip install fairscale
```
or via `transformers`' `extras`:
```bash
pip install transformers[fairscale]
```
(available starting from `transformers==4.6.0`) or find more details on [the FairScale's GitHub page](https://github.com/facebookresearch/fairscale/#installation).
If you're still struggling with the build, first make sure to read [CUDA Extension Installation Notes](#zero-install-notes).
If it's still not resolved the build issue, here are a few more ideas.
`fairscale` seems to have an issue with the recently introduced by pip build isolation feature. If you have a problem
with it, you may want to try one of:
```bash
pip install fairscale --no-build-isolation .
```
or:
```bash
git clone https://github.com/facebookresearch/fairscale/
cd fairscale
rm -r dist build
python setup.py bdist_wheel
pip uninstall -y fairscale
pip install dist/fairscale-*.whl
```
`fairscale` also has issues with building against pytorch-nightly, so if you use it you may have to try one of:
```bash
pip uninstall -y fairscale; pip install fairscale --pre \
-f https://download.pytorch.org/whl/nightly/cu110/torch_nightly \
--no-cache --no-build-isolation
```
or:
```bash
pip install -v --disable-pip-version-check . \
-f https://download.pytorch.org/whl/nightly/cu110/torch_nightly --pre
```
Of course, adjust the urls to match the cuda version you use.
If after trying everything suggested you still encounter build issues, please, proceed with the GitHub Issue of
[FairScale](https://github.com/facebookresearch/fairscale/issues).
**Usage**:
To use the first version of Sharded data-parallelism, add `--sharded_ddp simple` to the command line arguments, and
make sure you have added the distributed launcher `-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.
For example here is how you could use it for `run_translation.py` with 2 GPUs:
```bash
python -m torch.distributed.launch --nproc_per_node=2 examples/pytorch/translation/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--source_lang en --target_lang ro \
--fp16 --sharded_ddp simple
```
Notes:
- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with `--fp16` too, to make things even faster.
- One of the main benefits of enabling `--sharded_ddp simple` is that it uses a lot less GPU memory, so you should be
able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
significantly shorter training time.
3. To use the second version of Sharded data-parallelism, add `--sharded_ddp zero_dp_2` or `--sharded_ddp zero_dp_3` to the command line arguments, and make sure you have added the distributed launcher `-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.
For example here is how you could use it for `run_translation.py` with 2 GPUs:
```bash
python -m torch.distributed.launch --nproc_per_node=2 examples/pytorch/translation/run_translation.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--source_lang en --target_lang ro \
--fp16 --sharded_ddp zero_dp_2
```
`zero_dp_2` is an optimized version of the simple wrapper, while `zero_dp_3` fully shards model weights,
gradients and optimizer states.
Both are compatible with adding `cpu_offload` to enable ZeRO-offload (activate it like this: `--sharded_ddp "zero_dp_2 cpu_offload"`).
Notes:
- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with `--fp16` too, to make things even faster.
- The `cpu_offload` additional option requires `--fp16`.
- This is an area of active development, so make sure you have a source install of fairscale to use this feature as
some bugs you encounter may have been fixed there already.
Known caveats:
- This feature is incompatible with `--predict_with_generate` in the _run_translation.py_ script.
- Using `--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
`FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.
Sections that were moved:
[ <a href="./deepspeed#deepspeed-trainer-integration">DeepSpeed</a><a id="deepspeed"></a>
| <a href="./deepspeed#deepspeed-installation">Installation</a><a id="installation"></a>
| <a href="./deepspeed#deepspeed-multi-gpu">Deployment with multiple GPUs</a><a id="deployment-with-multiple-gpus"></a>
| <a href="./deepspeed#deepspeed-one-gpu">Deployment with one GPU</a><a id="deployment-with-one-gpu"></a>
| <a href="./deepspeed#deepspeed-notebook">Deployment in Notebooks</a><a id="deployment-in-notebooks"></a>
| <a href="./deepspeed#deepspeed-config">Configuration</a><a id="configuration"></a>
| <a href="./deepspeed#deepspeed-config-passing">Passing Configuration</a><a id="passing-configuration"></a>
| <a href="./deepspeed#deepspeed-config-shared">Shared Configuration</a><a id="shared-configuration"></a>
| <a href="./deepspeed#deepspeed-zero">ZeRO</a><a id="zero"></a>
| <a href="./deepspeed#deepspeed-zero2-config">ZeRO-2 Config</a><a id="zero-2-config"></a>
| <a href="./deepspeed#deepspeed-zero3-config">ZeRO-3 Config</a><a id="zero-3-config"></a>
| <a href="./deepspeed#deepspeed-nvme">NVMe Support</a><a id="nvme-support"></a>
| <a href="./deepspeed#deepspeed-zero2-zero3-performance">ZeRO-2 vs ZeRO-3 Performance</a><a id="zero-2-vs-zero-3-performance"></a>
| <a href="./deepspeed#deepspeed-zero2-example">ZeRO-2 Example</a><a id="zero-2-example"></a>
| <a href="./deepspeed#deepspeed-zero3-example">ZeRO-3 Example</a><a id="zero-3-example"></a>
| <a href="./deepspeed#deepspeed-optimizer">Optimizer</a><a id="optimizer"></a>
| <a href="./deepspeed#deepspeed-scheduler">Scheduler</a><a id="scheduler"></a>
| <a href="./deepspeed#deepspeed-fp32">fp32 Precision</a><a id="fp32-precision"></a>
| <a href="./deepspeed#deepspeed-amp">Automatic Mixed Precision</a><a id="automatic-mixed-precision"></a>
| <a href="./deepspeed#deepspeed-bs">Batch Size</a><a id="batch-size"></a>
| <a href="./deepspeed#deepspeed-grad-acc">Gradient Accumulation</a><a id="gradient-accumulation"></a>
| <a href="./deepspeed#deepspeed-grad-clip">Gradient Clipping</a><a id="gradient-clipping"></a>
| <a href="./deepspeed#deepspeed-weight-extraction">Getting The Model Weights Out</a><a id="getting-the-model-weights-out"></a>
]