add model scaling section (#15119)
* add model scaling section
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* integrate reviewer feedback
* initialize GPU properly
* add note about BnB optimizer
* move doc from `scaling.mdx` to `performance.mdx`
* integrate reviewer feedback
* revert section levels

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
b5c6fdecf0
commit
d923f76203
@@ -16,13 +16,479 @@ limitations under the License.
# Performance and Scalability: How To Fit a Bigger Model and Train It Faster

For now the software sections of this document are mainly PyTorch-specific, but the guide can be extended to other frameworks in the future.

> _Or how to escape the dreaded "RuntimeError: CUDA error: out of memory" error._

## Quick notes

[[open-in-colab]]

Training ever larger models can become challenging even on modern GPUs. Due to their immense size, we often run out of GPU memory and training can take a very long time. In this section we have a look at a few tricks to reduce the memory footprint and speed up training for large models, and at how they are integrated in the [`Trainer`] and [🤗 Accelerate](https://huggingface.co/docs/accelerate/). Before we start, make sure you have installed the following libraries:

```bash
pip install transformers datasets accelerate nvidia-ml-py3
```

The `nvidia-ml-py3` library allows us to monitor the memory usage of the models from within Python. You might be familiar with the `nvidia-smi` command in the terminal - this library allows us to access the same information directly in Python.

Then we create some dummy data: random token IDs between 100 and 30000 and binary labels for a classifier. In total, we get 512 sequences each with length 512 and store them in a [`Dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=dataset#datasets.Dataset) with PyTorch format.

```py
import numpy as np
from datasets import Dataset


seq_len, dataset_size = 512, 512
dummy_data = {
    "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
    # binary labels (0 or 1); the upper bound of `randint` is exclusive
    "labels": np.random.randint(0, 2, (dataset_size)),
}
ds = Dataset.from_dict(dummy_data)
ds.set_format("pt")
```

We want to print some summary statistics for the GPU utilization and the training run with the [`Trainer`]. We set up two helper functions to do just that:

```py
from pynvml import *


def print_gpu_utilization():
    # query the memory usage of GPU 0 via NVML
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
```

Let's verify that we start with free GPU memory:

```py
>>> print_gpu_utilization()
GPU memory occupied: 0 MB.
```

That looks good: the GPU memory is not occupied, as we would expect before we load any models. If that's not the case on your machine make sure to stop all processes that are using GPU memory. However, not all free GPU memory can be used by the user. When a model is loaded to the GPU, the CUDA kernels are also loaded, which can take up 1-2GB of memory. To see how much that is, we load a tiny tensor into the GPU, which triggers the kernels to be loaded as well.

```py
>>> import torch


>>> torch.ones((1, 1)).to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 1343 MB.
```

We see that the kernels alone take up 1.3GB of GPU memory. Now let's see how much space the model uses.

## Load Model

First, we load the `bert-large-uncased` model. We load the model weights directly to the GPU so that we can check how much space just the weights use.

```py
>>> from transformers import AutoModelForSequenceClassification


>>> model = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased").to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 2631 MB.
```

We can see that the model weights alone take up 1.3 GB of GPU memory. The exact number depends on the specific GPU you are using. Note that on newer GPUs a model can sometimes take up more space since the weights are loaded in an optimized fashion that speeds up the usage of the model. Now we can also quickly check if we get the same result as with the `nvidia-smi` CLI:

```bash
nvidia-smi
```

```bash
Tue Jan 11 08:58:05 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03 Driver Version: 460.91.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... On | 00000000:00:04.0 Off | 0 |
| N/A 37C P0 39W / 300W | 2631MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 3721 C ...nvs/codeparrot/bin/python 2629MiB |
+-----------------------------------------------------------------------------+
```

We get the same number as before and you can also see that we are using a V100 GPU with 16GB of memory. So now we can start training the model and see how the GPU memory consumption changes. First, we set up a few standard training arguments that we will use across all our experiments:

```py
default_args = {
    "output_dir": "tmp",
    "evaluation_strategy": "steps",
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}
```

<Tip>

Note: In order to properly clear the memory after experiments we need to restart the Python kernel between experiments. Run all the steps above and then just one of the experiments below.

</Tip>

## Vanilla Training

As a first experiment we will use the [`Trainer`] and train the model without any further modifications and with a batch size of 4:

```py
from transformers import TrainingArguments, Trainer, logging

logging.set_verbosity_error()


training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 57.82
Samples/second: 8.86
GPU memory occupied: 14949 MB.
```

We see that already a relatively small batch size almost fills up our GPU's entire memory. However, a larger batch size can often result in faster model convergence or better end performance. So ideally we want to tune the batch size to our model's needs and not to the GPU limitations. A simple trick to effectively train with a larger batch size is gradient accumulation.

## Gradient Accumulation

The idea behind gradient accumulation is to calculate the gradients in smaller steps instead of computing them for the whole batch at once: we iterate over smaller batches, do a forward and a backward pass through the model, and accumulate the gradients in the process. Once enough gradients are accumulated, we run the model's optimization step. This way we can easily increase the overall batch size to numbers that would never fit into the GPU's memory. In turn, however, the added forward and backward passes can slow down training a bit.
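
Conceptually, a minimal sketch of gradient accumulation in plain PyTorch looks like the following (assuming a generic `model`, `optimizer`, and `dataloader`; the [`Trainer`] does all of this for us when we set `gradient_accumulation_steps`):

```py
# Accumulate gradients over several small batches before running one optimizer step.
accumulation_steps = 4

model.train()
for step, batch in enumerate(dataloader, start=1):
    loss = model(**batch).loss
    # normalize so the accumulated gradient is the average over the small batches
    loss = loss / accumulation_steps
    loss.backward()
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```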

We can use gradient accumulation in the [`Trainer`] by simply adding the `gradient_accumulation_steps` argument to [`TrainingArguments`]. Let's see how it impacts the model's memory footprint:

```py
training_args = TrainingArguments(per_device_train_batch_size=1, gradient_accumulation_steps=4, **default_args)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 66.03
Samples/second: 7.75
GPU memory occupied: 8681 MB.
```

We can see that the memory footprint was dramatically reduced at the cost of being only slightly slower than the vanilla run. Of course, this would change as you increase the number of accumulation steps. In general you want to max out the GPU usage as much as possible. So in our case, the batch size of 4 was already pretty close to the GPU's limit. If we wanted to train with a batch size of 64 we should not use `per_device_train_batch_size=1` and `gradient_accumulation_steps=64` but instead `per_device_train_batch_size=4` and `gradient_accumulation_steps=16`, which has the same effective batch size while making better use of the available GPU resources.
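
For illustration, the batch-size-64 setup from the example above would look like this (reusing the same `default_args` as before):

```py
# effective batch size: 4 * 16 = 64
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    **default_args,
)
```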

Next we have a look at another trick to save a little bit more GPU memory called gradient checkpointing.

## Gradient Checkpointing

Even when we set the batch size to 1 and use gradient accumulation we can still run out of memory when working with large models. In order to compute the gradients during the backward pass all activations from the forward pass are normally saved. This can create a big memory overhead. Alternatively, one could forget all activations during the forward pass and recompute them on demand during the backward pass. This would however add a significant computational overhead and slow down training.

Gradient checkpointing strikes a compromise between the two approaches and saves strategically selected activations throughout the computational graph so only a fraction of the activations need to be re-computed for the gradients. See [this great article](https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9) explaining the ideas behind gradient checkpointing.
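
To illustrate the mechanism, here is a minimal sketch that applies PyTorch's `torch.utils.checkpoint.checkpoint` to a hypothetical stack of layers; calling a model's `gradient_checkpointing_enable` method does the equivalent internally for supported architectures:

```py
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# a hypothetical stack of layers, just to illustrate the idea
layers = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])
x = torch.randn(4, 1024, requires_grad=True)

hidden = x
for layer in layers:
    # the activations inside `layer` are not stored during the forward pass;
    # they are recomputed on the fly during the backward pass
    hidden = checkpoint(layer, hidden)

hidden.sum().backward()
```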

To enable gradient checkpointing in the [`Trainer`] we only need to pass it as a flag to the [`TrainingArguments`]. Everything else is handled under the hood:

```py
training_args = TrainingArguments(
    per_device_train_batch_size=1, gradient_accumulation_steps=4, gradient_checkpointing=True, **default_args
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 85.47
Samples/second: 5.99
GPU memory occupied: 6775 MB.
```

We can see that this saved some more memory but at the same time training became a bit slower. A general rule of thumb is that gradient checkpointing slows down training by about 20%. Let's have a look at another method with which we can regain some speed: mixed precision training.

## FP16 Training

The idea of mixed precision training is that not all variables need to be stored in full (32-bit) floating point precision. If we can reduce the precision, the variables and their computations are faster. The main advantage comes from saving the activations in half (16-bit) precision. Although the gradients are also computed in half precision they are converted back to full precision for the optimization step so no memory is saved here. Since the model is present on the GPU in both 16-bit and 32-bit precision this can use more GPU memory (1.5x of the original model size ends up on the GPU), especially for small batch sizes. Since some computations are performed in full and some in half precision this approach is also called mixed precision training.
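
Under the hood this is typically implemented with PyTorch's automatic mixed precision (AMP). As a minimal sketch in plain PyTorch (assuming a generic `model`, `optimizer`, and `dataloader`), a mixed precision training loop looks roughly like this:

```py
import torch

scaler = torch.cuda.amp.GradScaler()

model.train()
for batch in dataloader:
    optimizer.zero_grad()
    # run the forward pass in half precision where it is safe to do so
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss
    # scale the loss to avoid underflow of the fp16 gradients, then unscale for the step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```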

In the [`Trainer`], enabling mixed precision training is just a matter of setting the `fp16` flag to `True`:

```py
training_args = TrainingArguments(per_device_train_batch_size=4, fp16=True, **default_args)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 27.46
Samples/second: 18.64
GPU memory occupied: 13939 MB.
```

We can see that this is almost twice as fast as the vanilla training. Let's add it to the mix of the previous methods:

```py
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    **default_args,
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 50.76
Samples/second: 10.09
GPU memory occupied: 7275 MB.
```

We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster. But we are not done yet! There is another area where we can save GPU memory: the optimizer.

## Optimizer

The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves good convergence by storing the rolling average of the previous gradients which, however, adds an additional memory footprint of the order of the number of model parameters. One remedy for this is to use an alternative optimizer such as Adafactor.
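
As a rough back-of-the-envelope example: AdamW in full precision keeps two extra states per parameter (the running averages of the gradients and of their squares), each stored in 4 bytes. For a model with roughly 340M parameters, such as `bert-large-uncased`, that amounts to about 2 × 340M × 4 bytes ≈ 2.7 GB of additional GPU memory on top of the weights and gradients.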

### Adafactor

Instead of keeping the rolling average for each element in the weight matrices Adafactor only stores aggregated information (row- and column-wise sums of the rolling averages) which reduces the footprint considerably. One downside of Adafactor is that in some instances convergence can be slower than Adam's, so some experimentation is advised here. We can use Adafactor simply by setting `optim="adafactor"`:

```py
training_args = TrainingArguments(per_device_train_batch_size=4, optim="adafactor", **default_args)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 64.31
Samples/second: 7.96
GPU memory occupied: 12295 MB.
```

We can see that this saves a few more GB on the GPU. Let's see how it looks when we add it to the other methods we introduced earlier:

```py
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    optim="adafactor",
    **default_args,
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 56.54
Samples/second: 9.06
GPU memory occupied: 4847 MB.
```

We went from 15 GB memory usage to 5 GB - a 3x improvement while maintaining the throughput! However, as mentioned before, the convergence of Adafactor can be worse than Adam. There is an alternative to Adafactor called 8-bit Adam that takes a slightly different approach.

### 8-bit Adam

Instead of aggregating optimizer states like Adafactor, 8-bit Adam keeps the full state and quantizes it. Quantization means that it stores the state with lower precision and dequantizes it only for the optimization. This is similar to the idea behind FP16 training where using variables with lower precision saves memory.
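
To build some intuition for what quantization means here, the following toy example quantizes a small tensor to 8 bits and back using simple absmax scaling (this is not how `bitsandbytes` is implemented internally - it uses block-wise dynamic quantization - but it conveys the basic idea):

```py
import torch

state = torch.randn(5)
scale = state.abs().max() / 127
quantized = torch.round(state / scale).to(torch.int8)  # stored with 1 byte per value instead of 4
dequantized = quantized.float() * scale  # approximation used when the update is applied
print(state)
print(dequantized)
```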

In contrast to the previous approaches, this one is not integrated into the [`Trainer`] as a simple flag. We need to install the 8-bit optimizer and then pass it as a custom optimizer to the [`Trainer`]. Follow the installation guide in the GitHub [repo](https://github.com/facebookresearch/bitsandbytes) to install the `bitsandbytes` library that implements the 8-bit Adam optimizer.

Once installed, we just need to initialize the optimizer. Although this looks like a considerable amount of work it actually just involves two steps: first we group the model's parameters into two groups, one to which we apply weight decay and one to which we don't (usually, biases and layer norm parameters are not weight decayed). Then, in a second step, we do some argument housekeeping to use the same parameters as the previously used AdamW optimizer.

<Tip>

Note that in order to use the 8-bit optimizer with an existing pretrained model a change to the embedding layer is needed.
Read [this issue](https://github.com/huggingface/transformers/issues/14819) for more information.

</Tip>

```py
import bitsandbytes as bnb
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)

# group the parameters: weight decay is applied to everything except biases and layer norms
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": training_args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]

# reuse the same hyperparameters the AdamW optimizer would get from the training arguments
optimizer_kwargs = {
    "betas": (training_args.adam_beta1, training_args.adam_beta2),
    "eps": training_args.adam_epsilon,
    "lr": training_args.learning_rate,
}
adam_bnb_optim = bnb.optim.Adam8bit(optimizer_grouped_parameters, **optimizer_kwargs)
```

We can now pass the custom optimizer as an argument to the `Trainer`:

```py
trainer = Trainer(model=model, args=training_args, train_dataset=ds, optimizers=(adam_bnb_optim, None))
result = trainer.train()
print_summary(result)
```

```
Time: 55.95
Samples/second: 9.15
GPU memory occupied: 13085 MB.
```

We can see that we get a similar memory improvement as with Adafactor while keeping the full rolling average of the gradients. Let's repeat the experiment with the full settings:

```py
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    **default_args,
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds, optimizers=(adam_bnb_optim, None))
result = trainer.train()
print_summary(result)
```

```
Time: 49.46
Samples/second: 10.35
GPU memory occupied: 5363 MB.
```

Again, we get about a 3x memory improvement and even slightly higher throughput than with Adafactor. So we have seen how we can optimize the memory footprint of large models. The following plot summarizes all our experiments:



## Using 🤗 Accelerate

So far we have used the [`Trainer`] to run the experiments, but a more flexible alternative is to use 🤗 Accelerate. With 🤗 Accelerate you have full control over the training loop and can essentially write the loop in pure PyTorch with some minor modifications. In turn it allows you to easily scale across different infrastructures such as CPUs, GPUs, TPUs, or distributed multi-GPU setups without changing any code. Let's see what it takes to implement all of the above tweaks in 🤗 Accelerate. We can still use the [`TrainingArguments`] to wrap the training settings:

```py
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    **default_args,
)
```

The full example training loop with 🤗 Accelerate is only a handful of lines of code long:

```py
from accelerate import Accelerator
from torch.utils.data.dataloader import DataLoader

dataloader = DataLoader(ds, batch_size=training_args.per_device_train_batch_size)

if training_args.gradient_checkpointing:
    model.gradient_checkpointing_enable()

accelerator = Accelerator(fp16=training_args.fp16)
model, optimizer, dataloader = accelerator.prepare(model, adam_bnb_optim, dataloader)

model.train()
for step, batch in enumerate(dataloader, start=1):
    loss = model(**batch).loss
    loss = loss / training_args.gradient_accumulation_steps
    accelerator.backward(loss)
    if step % training_args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

First we wrap the dataset in a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). Then we can enable gradient checkpointing by calling the model's [`~PreTrainedModel.gradient_checkpointing_enable`] method. When we initialize the [`Accelerator`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator) we can specify if we want to use mixed precision training and it will take care of it for us in the [`prepare`] call. During the [`prepare`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.prepare) call the dataloader will also be distributed across workers should we use multiple GPUs. We use the same 8-bit optimizer from the earlier experiments.

Finally, we can write the main training loop. Note that the `backward` call is handled by 🤗 Accelerate. We can also see how gradient accumulation works: we normalize the loss so we get the average at the end of accumulation and once we have enough steps we run the optimization. Now the question is: does this use the same amount of memory as the previous steps? Let's check:

```py
>>> print_gpu_utilization()
GPU memory occupied: 5363 MB.
```

Indeed it does. Implementing these optimization techniques with 🤗 Accelerate only takes a handful of lines of code and comes with the benefit of more flexibility in the training loop.

Now, let's take a step back and discuss what we should optimize for when scaling the training of large models.

## How to scale

When we train models there are two aspects we want to optimize at the same time:

- Data throughput/training time
- Model performance

We have seen that each method changes the memory usage and throughput. In general we want to maximize the throughput (samples/second) to minimize the training cost. This is generally achieved by utilizing the GPU as much as possible and thus filling GPU memory to its limit. For example, as mentioned earlier, we only employ gradient accumulation when we want to use a batch size beyond the size of the GPU memory. If the desired batch size fits into memory then there is no reason to apply gradient accumulation, which will only slow down training.

The second objective is model performance. Just because we can does not mean we should use a large batch size. As part of hyperparameter tuning you should determine which batch size yields the best result and then optimize the throughput accordingly.

Sometimes, even when applying all the above tweaks, the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU, for example switching from a K80 (which you typically get on Google Colab) to a fancier GPU such as a V100 or A100. Although they are more expensive, they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture. For some applications, such as pretraining, this might still not be fast enough. In this case you want to scale your experiment to several GPUs.

## Multi-GPU Training

If your model fits on a single GPU, scaling to many GPUs can be achieved fairly easily with data parallelism. The idea is very similar to gradient accumulation, with the distinction that instead of running the forward and backward passes during the accumulation in sequence on a single machine, they are performed in parallel on multiple machines. So each GPU gets a small batch, runs the forward and backward passes, and then the gradients from all machines are aggregated and the model is optimized. You can combine this with all the methods we described before. For example, if you have 4 GPUs and use `per_device_train_batch_size=12` and `gradient_accumulation_steps=3` you will have an effective batch size of `4*12*3=144`.

The [`Trainer`] allows for distributed training and if you execute your [`Trainer`] training script on a machine with multiple GPUs it will automatically utilize all of them, hence the name `per_device_train_batch_size`. In 🤗 Accelerate you can configure the infrastructure setup with the following command:

```bash
accelerate config
```
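
For example, assuming a hypothetical training script called `train.py`, a multi-GPU run could then be launched like this (the exact launcher depends on your setup):

```bash
# with 🤗 Accelerate, using the configuration created by `accelerate config`
accelerate launch train.py

# or, for a Trainer script, with the PyTorch distributed launcher on 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 train.py
```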

Until now we have operated under the assumption that we can fit the model onto a single GPU, with or without the tricks introduced above. But what if this is not possible? We still have a few tricks up our sleeves!

## What if my model still does not fit?

If the model does not fit on a single GPU with all the mentioned tricks there are still more methods we can apply, although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimizations to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Model Parallelism" section](parallelism).

This concludes the practical part of this guide for scaling the training of large models. The following section goes into more details on some of the aspects discussed above.

## Further discussions

This section gives brief ideas on how to make training faster and support bigger models. Later sections will expand, demonstrate and elucidate each of these.

### Faster Training
## Faster Training

Hardware:

@@ -36,7 +502,7 @@ Software:

- fp16 (autocast caching)

### Bigger Models
## Bigger Models

Hardware:

@@ -570,7 +1036,7 @@ One of the important requirements to reach great training speed is the ability t

- `DataLoader(pin_memory=True, ...)` which ensures that the data gets preloaded into the pinned memory on CPU and typically leads to much faster transfers from CPU to GPU memory.
- `DataLoader(num_workers=4, ...)` - spawn several workers to pre-load data faster - during training watch the GPU utilization stats and if it's far from 100% experiment with raising the number of workers. Of course, the problem could be elsewhere so a very big number of workers won't necessarily lead to better performance (see the sketch after this list).
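
As a minimal sketch, both `DataLoader` arguments combined (assuming the dummy `ds` dataset from the quick notes above and an arbitrary batch size):

```py
from torch.utils.data import DataLoader

# pinned CPU memory speeds up host-to-GPU copies; extra workers pre-load batches in parallel
dataloader = DataLoader(ds, batch_size=4, pin_memory=True, num_workers=4)
```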

### Faster optimizer
## Faster optimizer

pytorch-nightly introduced `torch.optim._multi_tensor` which should significantly speed up the optimizers for situations with lots of small feature tensors. It should eventually become the default, but if you want to experiment with it sooner and don't mind using the bleeding edge, see: https://github.com/huggingface/transformers/issues/9965