[performance doc] Power and Cooling (#14935)
* [performance doc] Power and Cooling

* more docs

* Update docs/source/performance.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* reword

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

This commit is contained in:
parent 3e9fdcf019
commit 37bc0b4e53

@@ -42,24 +42,48 @@ Hardware:
- bigger GPUs
- more GPUs
- more CPU and NVMe (offloaded to by [DeepSpeed-Infinity](deepspeed#nvme-support))

Software:

- Deepspeed ZeRO
- Deepspeed ZeRO-Offload
- Megatron-LM 3D Parallelism
- Pipeline Parallelism
- Tensor Parallelism
- Model Scalability (ZeRO and 3D Parallelism)
- Low-memory Optimizers
- fp16/bf16 (smaller data/faster throughput)
- tf32 (faster throughput)
- Gradient accumulation
- Gradient checkpointing
- Sparsity

## Hardware

### Power and Cooling

If you bought an expensive high-end GPU, make sure you give it the correct power and sufficient cooling.

**Power**:

Some high-end consumer GPU cards have 2 and sometimes 3 PCI-E 8-Pin power sockets. Make sure you have as many independent 12V PCI-E 8-Pin cables plugged into the card as there are sockets. Do not use the 2 splits at one end of the same cable (also known as a pigtail cable). That is, if you have 2 sockets on the GPU, you want 2 PCI-E 8-Pin cables going from your PSU to the card and not one that has 2 PCI-E 8-Pin connectors at the end! You won't get the full performance out of your card otherwise.

Each PCI-E 8-Pin power cable needs to be plugged into a 12V rail on the PSU side and can supply up to 150W of power.

Some other cards may use PCI-E 12-Pin connectors, and these can deliver up to 500-600W of power.

Low-end cards may use 6-Pin connectors, which supply up to 75W of power.

Additionally you want a high-end PSU that has stable voltage. Some lower quality ones may not give the card the stable voltage it needs to function at its peak.

And of course the PSU needs to have enough unused Watts to power the card.
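
If you want to check what the card actually reports, here is a minimal sketch that reads the power draw and the enforced power limit via NVML, assuming the `pynvml` package is installed (`pip install pynvml`):

```python
# A minimal sketch: compare the GPU's reported power draw to its enforced limit.
# Assumes `pynvml` is installed; both values are returned in milliwatts.
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first GPU

draw_w = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000
limit_w = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000
print(f"power draw: {draw_w:.0f}W / enforced limit: {limit_w:.0f}W")

pynvml.nvmlShutdown()
```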

**Cooling**:

When a GPU gets overheated it will start throttling down and will not deliver full performance, and it will shut down if it gets too hot.

It's hard to tell the exact best temperature to strive for when a GPU is heavily loaded, but probably anything under +80C is good, and lower is better - perhaps 70-75C is an excellent range to be in. Throttling is likely to start at around 84-90C. But besides throttling performance, a prolonged period of very high temperature is likely to reduce the lifespan of a GPU.
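
To keep an eye on this during training, here is a minimal monitoring sketch, again assuming `pynvml` is installed, that reports the GPU temperature and whether the driver is currently applying thermal throttling:

```python
# A minimal sketch: report GPU temperature and thermal throttling status.
# Assumes `pynvml` is installed; the throttle-reason query may not be
# supported on every GPU/driver combination.
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first GPU

temp_c = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
reasons = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(handle)
thermal_mask = (
    pynvml.nvmlClocksThrottleReasonSwThermalSlowdown
    | pynvml.nvmlClocksThrottleReasonHwThermalSlowdown
)
print(f"temperature: {temp_c}C, thermal throttling: {bool(reasons & thermal_mask)}")

pynvml.nvmlShutdown()
```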

### Multi-GPU Connectivity

If you use multiple GPUs, the way cards are inter-connected can have a huge impact on the total training time.

@@ -163,6 +187,14 @@ Software: `pytorch-1.8-to-be` + `cuda-11.0` / `transformers==4.3.0.dev0`

## Software

### Model Scalability

When you can't fit a model into the available GPU memory, you need to start using a solution that allows you to scale a large model to use multiple GPUs in parallel.

For in-depth details on ZeRO and various other model parallelism protocols please see: [Model Parallelism](parallelism)
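
As one illustrative path, a DeepSpeed ZeRO config can be handed directly to the 🤗 Trainer. This is only a sketch with placeholder values; see the dedicated [DeepSpeed integration](deepspeed) docs for the real details:

```python
# A minimal sketch: pass a DeepSpeed ZeRO stage-2 config straight to the Trainer.
# The values below are illustrative only; tune them for your setup.
from transformers import TrainingArguments

ds_config = {
    "zero_optimization": {"stage": 2},         # shard optimizer states and gradients
    "fp16": {"enabled": "auto"},               # follow the Trainer's fp16 setting
    "train_micro_batch_size_per_gpu": "auto",  # follow per_device_train_batch_size
}

training_args = TrainingArguments(
    output_dir="output",
    deepspeed=ds_config,  # a path to a JSON config file works as well
)
```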

### Anatomy of Model's Operations

The Transformers architecture includes 3 main groups of operations, grouped below by compute-intensity.

@@ -307,6 +339,12 @@ Some amazing tutorials to read on mixed precision:

- @sgugger wrote a great explanation of mixed precision [here](https://docs.fast.ai/callback.fp16.html#A-little-bit-of-theory)
- Aleksey Bilogur's [A developer-friendly guide to mixed precision training with PyTorch](https://spell.ml/blog/mixed-precision-training-with-pytorch-Xuk7YBEAACAASJam)

You can also see a variety of benchmarks on fp16 vs other precisions:
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
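
For instance, a minimal way to turn on fp16 mixed precision with the 🤗 Trainer (the Python-side equivalent of passing `--fp16` to the example scripts) looks like this sketch:

```python
# A minimal sketch: enable fp16 mixed precision training via TrainingArguments.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    fp16=True,  # run forward/backward in fp16 while keeping fp32 master weights
)
```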

##### fp16 caching

pytorch `autocast`, which performs AMP, includes a caching feature, which speeds things up by caching fp16-converted values. Here is the full description from this [comment](https://discuss.pytorch.org/t/autocast-and-torch-no-grad-unexpected-behaviour/93475/3):

@@ -356,6 +394,10 @@ python -c 'import transformers; print(f"BF16 support is {transformers.file_utils

On the other hand bf16 has a much worse precision than fp16, so there are certain situations where you'd still want to use fp16 and not bf16.

You can also see a variety of benchmarks on bf16 vs other precisions:
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
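
If your setup supports it, enabling bf16 with the 🤗 Trainer is a small change, sketched below under the assumption of an Ampere-or-newer GPU and a transformers version with bf16 support:

```python
# A minimal sketch: enable bf16 mixed precision via TrainingArguments.
# Assumes an Ampere-or-newer GPU and a transformers version with bf16 support.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    bf16=True,  # same dynamic range as fp32, fewer mantissa bits than fp16
)
```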

##### bf16 Inference

@@ -388,10 +430,26 @@ Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.t

Note: you need `torch>=1.7` to enjoy this feature.

You can also see a variety of benchmarks on tf32 vs other precisions:
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
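
As a quick recap, this mode is toggled from PyTorch with two backend flags (a minimal sketch):

```python
# A minimal sketch: allow TF32 math on Ampere GPUs for matmuls and cuDNN convolutions.
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # matmuls may use TF32 tensor cores
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions may use TF32
```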

### Gradient Accumulation

Since gradient accumulation is essentially identical to having a larger batch size, just as with a larger batch size you are likely to see a 20-30% speedup due to the optimizer running less often. For example, see benchmarks for [RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004392537)
and [A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004592231).

To activate this feature in 🤗 Trainer add `--gradient_accumulation_steps 4` to its arguments (experiment with the value to get the best performance).

It's important to remember that with gradient accumulation you may end up with a much larger effective batch size, so you may need to adjust the learning rate and its warmup, and for very short datasets it will impact the loss as the training will end up doing fewer steps than normal.
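
The same setting expressed directly in code, as a minimal sketch with illustrative batch-size values:

```python
# A minimal sketch: gradient accumulation via TrainingArguments.
# Effective batch size per device = 8 * 4 = 32.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
)
```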

### Gradient Checkpointing

One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in the training speed due to recomputing parts of the graph during back-propagation. The slowdown will depend on the model but quite often it is around 20-30%.

This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers.
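
A minimal sketch of turning it on for a 🤗 Transformers model (the checkpoint name is just an example):

```python
# A minimal sketch: enable gradient checkpointing on a pretrained model.
# "bert-base-uncased" is only an example checkpoint.
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model.gradient_checkpointing_enable()
# Alternatively, pass gradient_checkpointing=True to TrainingArguments.
```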

@@ -414,6 +472,10 @@ https://docs.nvidia.com/deeplearning/performance/dl-performance-fully-connected/

For parameters that are small, there is also [Dimension Quantization Effects](https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#dim-quantization) to consider; this is where tiling happens and the right multiplier can have a significant speedup.

Additionally, as explained in the [Gradient Accumulation](#gradient-accumulation) section, the bigger the batch size, the less often the optimizer is run and the faster the training is (considering the same dataset length). See benchmarks
for [RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004392537)
and [A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1005033957).
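
To get a feel for the dimension quantization effect on your own hardware, here is a rough timing sketch (the exact numbers will vary a lot by GPU and dtype):

```python
# A rough sketch: time an fp16 matmul whose inner dimension is a multiple of 8
# against one that is off by one, to illustrate dimension quantization effects.
import torch

def time_matmul(k, steps=100):
    a = torch.randn(4096, k, dtype=torch.float16, device="cuda")
    b = torch.randn(k, 4096, dtype=torch.float16, device="cuda")
    a @ b  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(steps):
        a @ b
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # milliseconds per matmul

print(f"k=4096: {time_matmul(4096):.2f}ms  k=4095: {time_matmul(4095):.2f}ms")
```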

### DP vs DDP