Since for inference there is no need for the additional large memory used by the optimizer states and the gradients, you should be able to fit much larger batches and/or sequence lengths onto the same hardware.

Additionally, DeepSpeed is currently developing a related product called DeepSpeed-Inference, which has no relationship to the ZeRO technology, but instead uses tensor parallelism to scale models that can't fit onto a single GPU. This is a work in progress and we will provide the integration once that product is complete.

### Memory Requirements

Since DeepSpeed ZeRO can offload memory to the CPU (and NVMe), the framework provides utilities that let you estimate how much CPU and GPU memory will be needed, depending on the number of GPUs being used.

Let's estimate how much memory is needed to finetune "bigscience/T0_3B" on a single GPU:

```bash
$ python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("bigscience/T0_3B"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=1, num_nodes=1)'
[...]
Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 1 GPU per node.
SW: Model with 2783M total params, 65M largest layer params.
  per CPU  |  per GPU |   Options
   70.00GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
   70.00GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
   62.23GB |   5.43GB | offload_param=none, offload_optimizer=cpu , zero_init=1
   62.23GB |   5.43GB | offload_param=none, offload_optimizer=cpu , zero_init=0
    0.37GB |  46.91GB | offload_param=none, offload_optimizer=none, zero_init=1
   15.56GB |  46.91GB | offload_param=none, offload_optimizer=none, zero_init=0
```

So you can fit it on a single 80GB GPU with no CPU offload, or on a tiny 8GB GPU, but then you need ~60GB of CPU memory. (Remember, this is just the memory for params, optimizer states and gradients - you will need a bit more memory for CUDA kernels, activations and temps.)
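
As a rough sanity check on these numbers: with mixed precision and an Adam optimizer, ZeRO keeps about 16 bytes of model states per parameter - 2 for the fp16 weights, 2 for the fp16 gradients, and 4+4+4 for the fp32 master weights, optimizer momentum and variance. For a 2783M-parameter model that's roughly 2.783e9 * 16 bytes ≈ 44.5GB, which is in the same ballpark as the 46.91GB per-GPU figure above for the no-offload case (the estimator adds some headroom on top).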

Then it's a tradeoff of cost vs speed. It'll be cheaper to buy or rent a smaller GPU (or fewer GPUs, since you can use multiple GPUs with DeepSpeed ZeRO). But then it'll be slower, so even if you don't care about how fast something gets done, the slowdown has a direct impact on how long you use the GPU and therefore on cost. So experiment and compare which works best.

If you have enough GPU memory, make sure to disable CPU/NVMe offload, as it'll make everything faster.

For example, let's repeat the same estimate for 2 GPUs:

```bash
$ python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("bigscience/T0_3B"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=2, num_nodes=1)'
[...]
Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 2 GPUs per node.
SW: Model with 2783M total params, 65M largest layer params.
  per CPU  |  per GPU |   Options
   70.00GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
   70.00GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
   62.23GB |   2.84GB | offload_param=none, offload_optimizer=cpu , zero_init=1
   62.23GB |   2.84GB | offload_param=none, offload_optimizer=cpu , zero_init=0
    0.74GB |  23.58GB | offload_param=none, offload_optimizer=none, zero_init=1
   31.11GB |  23.58GB | offload_param=none, offload_optimizer=none, zero_init=0
```

So here you'd want 2x 32GB GPUs or higher without offloading to CPU.
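
Note how ZeRO-3 partitions the model states across the GPUs: the per-GPU memory for the no-offload case roughly halves when going from 1 GPU to 2 (46.91GB -> 23.58GB per GPU), while the CPU memory needed for full offload stays at 70.00GB, since the offloaded states cover the whole model regardless of how many GPUs share it.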

For full information please see [memory estimators](https://deepspeed.readthedocs.io/en/latest/memory.html).
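
If you'd rather get an estimate without downloading and instantiating the model first, that same page also documents `_cold` variants of the estimators that work directly from the parameter counts. A minimal sketch, assuming the signature documented there and reusing the parameter counts reported above:

```bash
$ python -c 'from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold; \
estimate_zero3_model_states_mem_needs_all_cold(total_params=2783e6, largest_layer_params=65e6, \
num_gpus_per_node=1, num_nodes=1)'
```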

### Filing Issues

Here is how to file an issue so that we can quickly get to the bottom of the problem and help you unblock your work.

If the `deepspeed` process gets killed at launch time without a traceback, that usually means that the program tried to allocate more CPU memory than your system has or than your process is allowed to allocate, and the OS kernel killed the process. This is because your configuration file most likely has either `offload_optimizer` or `offload_param` or both configured to offload to `cpu`. If you have NVMe, experiment with offloading to NVMe if you're running under ZeRO-3. Here is how you can [estimate how much memory is needed for a specific model](https://deepspeed.readthedocs.io/en/latest/memory.html).