[deepspeed docs] Activation Checkpointing (#22099)
* [deepspeed docs] Activation Checkpointing

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update deepspeed.mdx

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 5b85add7d5
commit 618697ef53
@@ -1198,6 +1198,20 @@ Other quick related performance notes:
- if you are training something from scratch, always try to have tensors with shapes that are divisible by 16 (e.g. hidden size). For batch size, try divisible by 2 at least. There is hardware-specific [wave and tile quantization](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/) divisibility if you want to squeeze even higher performance from your GPUs.
### Activation Checkpointing or Gradient Checkpointing
Activation checkpointing and gradient checkpointing are two distinct terms that refer to the same methodology. It's very confusing, but this is how it is.
Gradient checkpointing allows you to trade training speed for GPU memory, which either lets you overcome a GPU OOM or increase your batch size, which often leads to better performance.
HF Transformers models don't know anything about DeepSpeed's activation checkpointing, so if you try to enable that feature in the DeepSpeed config file, nothing will happen.
Therefore you have two ways to take advantage of this very beneficial feature:
1. If you want to use an HF Transformers model, you can do `model.gradient_checkpointing_enable()` or use `--gradient_checkpointing` in the HF Trainer, which will automatically enable this for you. `torch.utils.checkpoint` is used there (see the first sketch after this list).
2. If you write your own model and you want to use DeepSpeed's activation checkpointing, you can use the [API prescribed there](https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html). You can also take the HF Transformers modeling code and replace `torch.utils.checkpoint` with DeepSpeed's API (see the second sketch below). The latter is more flexible since it allows you to offload the forward activations to the CPU memory instead of recalculating them.
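To illustrate the first option, here is a minimal sketch of both ways of enabling gradient checkpointing in HF Transformers; the model name, `output_dir` and batch size are placeholder assumptions, not values taken from this document:

```python
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Option A: enable gradient checkpointing directly on the model.
# Under the hood the forward pass then goes through `torch.utils.checkpoint`.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
model.gradient_checkpointing_enable()

# Option B: let the HF Trainer enable it for you; `gradient_checkpointing=True`
# corresponds to the `--gradient_checkpointing` command-line flag.
args = TrainingArguments(
    output_dir="output_dir",            # placeholder
    per_device_train_batch_size=8,      # placeholder
    gradient_checkpointing=True,
)
trainer = Trainer(model=model, args=args)  # train_dataset etc. omitted
```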
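And for the second option, here is a rough sketch of what swapping `torch.utils.checkpoint` for DeepSpeed's activation checkpointing API might look like in your own modeling code; the module structure is invented for illustration, and the `configure()` arguments shown are assumptions to be checked against the DeepSpeed docs linked above:

```python
import torch.nn as nn
import deepspeed


class MyBlockStack(nn.Module):
    """Invented example module: a plain stack of layers."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, hidden_states):
        for layer in self.layers:
            # Before: hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
            # After: DeepSpeed's replacement for `torch.utils.checkpoint.checkpoint`.
            hidden_states = deepspeed.checkpointing.checkpoint(layer, hidden_states)
        return hidden_states


# Configure DeepSpeed activation checkpointing. Options such as
# `partition_activations` and `checkpoint_in_cpu` (offloading forward
# activations to CPU memory) are documented in the DeepSpeed API linked
# above; the values below are illustrative assumptions only.
deepspeed.checkpointing.configure(mpu_=None, partition_activations=True, checkpoint_in_cpu=True)
```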
### Optimizer and Scheduler
As long as you don't enable `offload_optimizer` you can mix and match DeepSpeed and HuggingFace schedulers and