Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)

* add gradient_accumulation_kwargs to AcceleratorConfig

* add suggestions from @muellerzr to docstrings, new behavior and tests

* Documentation suggestions from @muellerzr

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* addressed @muellerzr comments regarding tests and test utils

* moved accelerate version to top of file.

* @muellerzr's variable fix

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* address @amyeroberts. fix tests and docstrings

* address @amyeroberts additional suggestions

---------

Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Yu Chin Fabian Lim 2024-03-28 22:01:40 +08:00 committed by GitHub
parent a2a7f71604
commit 4df5b9b4b2
5 changed files with 145 additions and 13 deletions

src/transformers/testing_utils.py

@@ -52,6 +52,7 @@ from .integrations import (
)
from .integrations.deepspeed import is_deepspeed_available
from .utils import (
ACCELERATE_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
@@ -365,11 +366,13 @@ def require_nltk(test_case):
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
-def require_accelerate(test_case):
+def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
-return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+return unittest.skipUnless(
+    is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
+)(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):
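
For context (not part of this commit): a minimal sketch of how the new min_version argument can be used, pre-bound with functools.partial the way the test suite below does. The class and test names here are illustrative.

# illustrative example, not part of the diff
import unittest
from functools import partial

from transformers.testing_utils import require_accelerate

# Pre-bind the new ``min_version`` argument so it can be reused as a plain decorator.
require_accelerate_min_0_28 = partial(require_accelerate, min_version="0.28.0")


class ExampleTests(unittest.TestCase):
    @require_accelerate_min_0_28
    def test_needs_recent_accelerate(self):
        # Skipped automatically unless accelerate >= 0.28.0 is installed.
        self.assertTrue(True)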

src/transformers/trainer.py

@@ -4324,8 +4324,23 @@ class Trainer:
self.repo.git_push()
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
accelerator_config = self.args.accelerator_config.to_dict()
@@ -4337,6 +4352,8 @@ class Trainer:
even_batches=accelerator_config.pop("even_batches"),
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
)
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,

src/transformers/trainer_pt_utils.py

@@ -1185,6 +1185,15 @@ class AcceleratorConfig:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
gradient_accumulation_kwargs (`dict`, *optional*):
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
Any of the following (optional) keys are acceptable:
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
the latter is set to 1, otherwise an exception will be raised.
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
"""
@@ -1223,6 +1232,19 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
gradient_accumulation_kwargs: Optional[Dict] = field(
default=None,
metadata={
"help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
"Any of the following (optional) keys are acceptable: "
" num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
" the latter is set to 1, otherwise an exception will be raised. "
" adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
" sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
},
)
@classmethod
def from_json_file(cls, json_file):

src/transformers/utils/import_utils.py

@@ -805,9 +805,7 @@ def is_protobuf_available():
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
-if min_version is not None:
-    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
-return _accelerate_available
+return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
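
For context (not part of this commit): with the simplification above, the helper always applies the version floor; a quick illustration (the version string is an example).

# illustrative example, not part of the diff
from transformers.utils import is_accelerate_available

# True only when accelerate is installed and its version is >= the argument;
# with no argument, the packaged ACCELERATE_MIN_VERSION floor is used.
if is_accelerate_available("0.28.0"):
    print("gradient_accumulation_kwargs supported")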

tests/trainer/test_trainer.py

@@ -24,6 +24,7 @@ import subprocess
import sys
import tempfile
import unittest
from functools import partial
from itertools import product
from pathlib import Path
from typing import Dict, List
@@ -92,6 +93,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
@@ -127,6 +129,9 @@ if is_torch_available():
if is_safetensors_available():
import safetensors.torch
# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
# gradient accumulation kwargs configures gradient_state
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
-accelerator_config={
-    "split_batches": True,
-    "dispatch_batches": True,
-    "even_batches": False,
-    "use_seedable_sampler": True,
-},
+accelerator_config=accelerator_config,
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
@@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
accelerator_config = AcceleratorConfig(
-split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
+split_batches=True,
+dispatch_batches=True,
+even_batches=False,
+use_seedable_sampler=False,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
@require_accelerate_version_min_0_28
def test_accelerate_config_from_dataclass_grad_accum(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
grad_acc_kwargs = {
"num_steps": 10,
"adjust_scheduler": False,
"sync_with_dataloader": False,
"sync_each_batch": True,
}
accelerator_config = AcceleratorConfig(
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
gradient_accumulation_kwargs=grad_acc_kwargs,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
@require_accelerate_version_min_0_28
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# case - TrainingArguments.gradient_accumulation_steps == 1
# - gradient_accumulation_kwargs['num_steps'] == 1
# results in grad accum set to 1
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=1,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 1,
}
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
# case - TrainingArguments.gradient_accumulation_steps > 1
# - gradient_accumulation_kwargs['num_steps'] specified
# results in exception raised
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=2,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 10,
}
},
)
with self.assertRaises(Exception) as context:
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
@require_torch
@is_staging_test