add warning when using gradient_checkpointing with FSDP full shard (#31578)

* add warning when using `gradient_checkpointing` with FSDP full shard

* fix style

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add hybrid shard warn

* fix style

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Yun Dai 2024-07-09 15:55:57 -07:00 committed by GitHub
parent 6176d8f5ee
commit ad35309a62


@@ -1820,7 +1820,7 @@ class TrainingArguments:
             raise ValueError("warmup_steps must be either 0 or > 1")
         if isinstance(self.fsdp, bool):
-            self.fsdp = "full_shard" if self.fsdp else ""
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:
@@ -1831,6 +1831,15 @@
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+                " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
         if self.fsdp_config is None:
             self.fsdp_config = {}
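
For context, a minimal sketch of the two configurations this warning distinguishes (the `output_dir` value is illustrative, not from this commit): passing `gradient_checkpointing=True` alongside FSDP full sharding now logs the warning, while enabling `activation_checkpointing` through `fsdp_config` is the path the warning recommends.

    from transformers import TrainingArguments

    # Triggers the new warning: Trainer-level gradient checkpointing combined
    # with FSDP full shard introduces a redundant AllGather in the backward
    # pass (see huggingface/transformers#30404).
    args_warned = TrainingArguments(
        output_dir="out",  # illustrative path
        fsdp="full_shard",
        gradient_checkpointing=True,
    )

    # Recommended alternative per the warning: enable activation checkpointing
    # through fsdp_config so FSDP applies checkpointing to its own wrapped units.
    args_recommended = TrainingArguments(
        output_dir="out",  # illustrative path
        fsdp="full_shard",
        fsdp_config={"activation_checkpointing": True},
    )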