Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)
FSDP grad accum fix (#34645)
* add gradient accumulation steps tests for fsdp
* invert no_sync context to fix training for fsdp
Parent: 52ea4aa589
Commit: b0c0ba7b4d
```diff
@@ -2488,7 +2488,7 @@ class Trainer:
                 # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                 context = (
                     functools.partial(self.accelerator.no_sync, model=model)
-                    if i == len(batch_samples) - 1
+                    if i != len(batch_samples) - 1
                     else contextlib.nullcontext
                 )
                 with context():
```
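Before this change, only the last micro-batch was wrapped in `no_sync`, so gradients were synchronized on every accumulation step except the one that actually precedes the optimizer step; under FSDP that breaks training with gradient accumulation. The inverted condition skips synchronization on every micro-batch except the last. Below is a minimal sketch of that pattern, assuming a plain Accelerate-style loop with placeholder `accelerator`, `model`, `optimizer`, and `batch_samples` objects; it illustrates the corrected condition and is not the Trainer's actual code.

```python
# Minimal sketch of the corrected no_sync pattern, not the Trainer's actual loop.
# `accelerator`, `model`, `optimizer`, and `batch_samples` are placeholders.
import contextlib
import functools


def accumulate_step(accelerator, model, optimizer, batch_samples):
    for i, batch in enumerate(batch_samples):
        # Skip gradient synchronization on all but the last micro-batch, so that
        # FSDP/DDP all-reduces gradients only once per optimizer step.
        context = (
            functools.partial(accelerator.no_sync, model=model)
            if i != len(batch_samples) - 1
            else contextlib.nullcontext
        )
        with context():
            loss = model(**batch).loss / len(batch_samples)
            accelerator.backward(loss)

    # Gradients from all micro-batches are synchronized at this point; apply the update.
    optimizer.step()
    optimizer.zero_grad()
```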
```diff
@@ -224,6 +224,18 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
         cmd = launcher + script + args + fsdp_args
         execute_subprocess_async(cmd, env=self.get_env())
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    @require_torch_multi_accelerator
+    @slow
+    def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
+        launcher = get_launcher(distributed=True, use_accelerate=False)
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = self.get_base_args(output_dir, 1, 50).split() + [f"--{dtype}", "--gradient_accumulation_steps", "2"]
+        fsdp_args = ["--fsdp", f"{sharding_strategy} auto_wrap", "--fsdp_transformer_layer_cls_to_wrap", "BertLayer"]
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        cmd = launcher + script + args + fsdp_args
+        execute_subprocess_async(cmd, env=self.get_env())
+
     @parameterized.expand(dtypes)
     @require_torch_multi_accelerator
     @slow
```
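The new test simply shells out to the GLUE example under the distributed launcher with FSDP sharding and `--gradient_accumulation_steps 2`, so a regression in the inverted `no_sync` logic would surface as a training failure in this run. A rough standalone equivalent, outside the test harness, might look like the sketch below; the process count, model name, task, epoch count, and output path are assumptions for illustration, not the values produced by `get_base_args`.

```python
# Illustrative approximation of the command the new test assembles; the model,
# task, process count, and output path are assumptions, not the test's values.
import subprocess

launcher = ["torchrun", "--nproc_per_node", "2"]
script = ["examples/pytorch/text-classification/run_glue.py"]
args = [
    "--model_name_or_path", "bert-base-cased",
    "--task_name", "mrpc",
    "--do_train",
    "--num_train_epochs", "1",
    "--output_dir", "/tmp/fsdp_grad_accum_test",
    "--fp16",
    "--gradient_accumulation_steps", "2",
]
fsdp_args = [
    "--fsdp", "full_shard auto_wrap",
    "--fsdp_transformer_layer_cls_to_wrap", "BertLayer",
]

subprocess.run(launcher + script + args + fsdp_args, check=True)
```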