FSDP grad accum fix (#34645)

* add gradient accumulation tests for FSDP

* invert the no_sync condition so gradient synchronization is skipped on all but the last accumulation step, fixing FSDP training with gradient accumulation
Wing Lian, 2024-11-15 16:28:06 -05:00
commit b0c0ba7b4d (parent 52ea4aa589)
2 changed files with 13 additions and 1 deletion

src/transformers/trainer.py

@@ -2488,7 +2488,7 @@ class Trainer:
                     # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                     context = (
                         functools.partial(self.accelerator.no_sync, model=model)
-                        if i == len(batch_samples) - 1
+                        if i != len(batch_samples) - 1
                         else contextlib.nullcontext
                     )
                     with context():

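For context on why the condition is inverted: with DDP/FSDP gradient accumulation, gradient synchronization should be suppressed for every micro-batch except the last one in the accumulation window, so the all-reduce happens once per optimizer step. The old condition entered no_sync only on the last micro-batch, which is exactly backwards. A minimal sketch of the intended pattern, assuming an Accelerate setup similar to the Trainer's (the helper name accumulation_context and the loop variables are illustrative, not the Trainer's actual API):

import contextlib

def accumulation_context(accelerator, model, i, num_batches):
    # Suppress gradient all-reduce for every micro-batch except the last,
    # so FSDP/DDP synchronizes gradients once per optimizer step.
    if i != num_batches - 1:
        return accelerator.no_sync(model)
    return contextlib.nullcontext()

# Inside the inner training loop (illustrative names):
#   for i, batch in enumerate(batch_samples):
#       with accumulation_context(accelerator, model, i, len(batch_samples)):
#           loss = model(**batch).loss / len(batch_samples)
#           accelerator.backward(loss)
#   optimizer.step()
#   optimizer.zero_grad()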
tests/fsdp/test_fsdp.py

@@ -224,6 +224,18 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         cmd = launcher + script + args + fsdp_args
         execute_subprocess_async(cmd, env=self.get_env())
+
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    @require_torch_multi_accelerator
+    @slow
+    def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
+        launcher = get_launcher(distributed=True, use_accelerate=False)
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
+        fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        cmd = launcher + script + args + fsdp_args
+        execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(dtypes)
     @require_torch_multi_accelerator
     @slow
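To exercise the new test locally (assuming it lives in tests/fsdp/test_fsdp.py as shown above), the usual transformers slow-test invocation applies, e.g. RUN_SLOW=1 pytest tests/fsdp/test_fsdp.py -k test_basic_run_with_gradient_accumulation on a machine with multiple accelerators, since the test is gated by @slow and @require_torch_multi_accelerator.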