Arde/fsdp activation checkpointing (#25771)

* add FSDP config option to enable activation-checkpointing

* update docs

* add checks and remove redundant code

* fix formatting error
Arup De 2023-08-29 00:22:14 -07:00 committed by GitHub
parent 50573c648a
commit 738ecd17d8
3 changed files with 17 additions and 0 deletions


@@ -456,6 +456,10 @@ as the model saving with FSDP activated is only available with recent fixes.
If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
- `limit_all_gathers` can be specified in the config file.
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
- `activation_checkpointing` can be specified in the config file.
If `"True"`, activation checkpointing is enabled: FSDP reduces memory usage by clearing the activations of
certain layers and recomputing them during the backward pass, effectively trading extra computation time
for reduced memory usage (see the usage sketch after this section).
**A few caveats to be aware of**
- It is incompatible with `generate`, and therefore incompatible with `--predict_with_generate`
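For reference, a minimal sketch of how these config-file options map onto `TrainingArguments`, assuming a multi-GPU launch via `torchrun` or `accelerate launch` that is not shown here; only keys that appear in this commit are used, and `output_dir` is a placeholder:

from transformers import TrainingArguments

# Sketch: enable FSDP and its activation checkpointing through `fsdp_config`.
# The keys mirror the ones the Trainer reads in this commit.
training_args = TrainingArguments(
    output_dir="output",                   # placeholder output directory
    fsdp="full_shard",                     # turn FSDP on
    fsdp_config={
        "forward_prefetch": True,          # prefetch the next all-gather during forward
        "limit_all_gathers": True,         # throttle in-flight all-gathers
        "activation_checkpointing": True,  # recompute activations during backward
    },
    gradient_checkpointing=False,          # must stay off; see the Trainer check below
)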


@@ -3896,6 +3896,15 @@ class Trainer:
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
"activation_checkpointing", fsdp_plugin.activation_checkpointing
)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
"when using FSDP."
)
if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
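As a quick illustration of the new guard, a hedged sketch (not part of this commit) of the combination it rejects; the ValueError is raised when a `Trainer` is constructed with FSDP enabled:

from transformers import TrainingArguments

# Both flavors of checkpointing requested at once: the check above is expected
# to raise a ValueError during Trainer initialization when FSDP is active.
bad_args = TrainingArguments(
    output_dir="output",
    fsdp="full_shard",
    fsdp_config={"activation_checkpointing": True},
    gradient_checkpointing=True,
)
# Trainer(model=..., args=bad_args) -> ValueError with the message above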


@@ -482,6 +482,10 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If `True`, activation checkpointing is enabled. This technique reduces memory usage by clearing the
activations of certain layers and recomputing them during the backward pass, effectively trading extra
computation time for reduced memory usage (see the file-based sketch below).
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
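Since `fsdp_config` also accepts the path of a JSON config file rather than a dict, a hedged file-based sketch of the same setting; the file name `fsdp_config.json` is only an example:

import json

from transformers import TrainingArguments

# Write the FSDP options to a JSON file and point `fsdp_config` at its path.
with open("fsdp_config.json", "w") as f:
    json.dump({"activation_checkpointing": True, "limit_all_gathers": True}, f)

training_args = TrainingArguments(
    output_dir="output",
    fsdp="full_shard",
    fsdp_config="fsdp_config.json",
)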