mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)
* add gradient_accumulation_kwargs to AcceleratorConfig * add suggestions from @muellerzr to docstrings, new behavior and tests * Documentation suggestions from @muellerz Co-authored-by: Zach Mueller <muellerzr@gmail.com> * addressed @muellerzr comments regarding tests and test utils * moved accelerate version to top of file. * @muellerzr's variable fix Co-authored-by: Zach Mueller <muellerzr@gmail.com> * address @amyeroberts. fix tests and docstrings * address @amyeroberts additional suggestions --------- Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
parent
a2a7f71604
commit
4df5b9b4b2
@ -52,6 +52,7 @@ from .integrations import (
|
||||
)
|
||||
from .integrations.deepspeed import is_deepspeed_available
|
||||
from .utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_aqlm_available,
|
||||
@ -365,11 +366,13 @@ def require_nltk(test_case):
|
||||
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
return unittest.skipUnless(
|
||||
is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_fsdp(test_case, min_version: str = "1.12.0"):
|
||||
|
@ -4324,8 +4324,23 @@ class Trainer:
|
||||
self.repo.git_push()
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
|
||||
grad_acc_kwargs = {}
|
||||
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
|
||||
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
|
||||
|
||||
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
|
||||
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
|
||||
# raise because we do not know which setting is intended.
|
||||
raise ValueError(
|
||||
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
|
||||
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
|
||||
)
|
||||
elif "num_steps" not in grad_acc_kwargs:
|
||||
# take the gradient_accumulation_steps setting from TrainingArguments.
|
||||
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
|
||||
|
||||
grad_acc_kwargs["sync_with_dataloader"] = False
|
||||
|
||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
||||
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
@ -4337,6 +4352,8 @@ class Trainer:
|
||||
even_batches=accelerator_config.pop("even_batches"),
|
||||
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
||||
)
|
||||
# this would have been updated above, no need for it anymore
|
||||
accelerator_config.pop("gradient_accumulation_kwargs")
|
||||
args = {
|
||||
"deepspeed_plugin": self.args.deepspeed_plugin,
|
||||
"gradient_accumulation_plugin": gradient_accumulation_plugin,
|
||||
|
@ -1185,6 +1185,15 @@ class AcceleratorConfig:
|
||||
training results are fully reproducable using a different sampling technique. While seed-to-seed results
|
||||
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
|
||||
also be ran with [`~utils.set_seed`] for the best results.
|
||||
gradient_accumulation_kwargs (`dict`, *optional*):
|
||||
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
|
||||
Any of the following (optional) keys are acceptable:
|
||||
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
|
||||
the latter is set to 1, otherwise an exception will be raised.
|
||||
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
|
||||
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
|
||||
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
|
||||
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
|
||||
|
||||
"""
|
||||
|
||||
@ -1223,6 +1232,19 @@ class AcceleratorConfig:
|
||||
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
||||
},
|
||||
)
|
||||
gradient_accumulation_kwargs: Optional[Dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
|
||||
"Any of the following (optional) keys are acceptable: "
|
||||
" num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
|
||||
" the latter is set to 1, otherwise an exception will be raised. "
|
||||
" adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
|
||||
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
|
||||
" sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
|
||||
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
|
@ -805,9 +805,7 @@ def is_protobuf_available():
|
||||
|
||||
|
||||
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
|
||||
if min_version is not None:
|
||||
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
|
||||
return _accelerate_available
|
||||
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
|
||||
|
||||
|
||||
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
|
||||
|
@ -24,6 +24,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
@ -92,6 +93,7 @@ from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_safetensors_available,
|
||||
@ -127,6 +129,9 @@ if is_torch_available():
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
# gradient accumulation kwargs configures gradient_state
|
||||
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
|
||||
|
||||
def test_accelerator_config_from_dict(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
accelerator_config = {
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": True,
|
||||
}
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||
|
||||
# Leaves all options as something *not* basic
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config={
|
||||
"split_batches": True,
|
||||
"dispatch_batches": True,
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": True,
|
||||
},
|
||||
accelerator_config=accelerator_config,
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_yaml(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
"even_batches": False,
|
||||
"use_seedable_sampler": False,
|
||||
}
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
|
||||
json.dump(accelerator_config, f)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_dataclass(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
|
||||
accelerator_config = AcceleratorConfig(
|
||||
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
|
||||
split_batches=True,
|
||||
dispatch_batches=True,
|
||||
even_batches=False,
|
||||
use_seedable_sampler=False,
|
||||
)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||
|
||||
@require_accelerate_version_min_0_28
|
||||
def test_accelerate_config_from_dataclass_grad_accum(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
|
||||
grad_acc_kwargs = {
|
||||
"num_steps": 10,
|
||||
"adjust_scheduler": False,
|
||||
"sync_with_dataloader": False,
|
||||
"sync_each_batch": True,
|
||||
}
|
||||
accelerator_config = AcceleratorConfig(
|
||||
split_batches=True,
|
||||
dispatch_batches=True,
|
||||
even_batches=False,
|
||||
use_seedable_sampler=False,
|
||||
gradient_accumulation_kwargs=grad_acc_kwargs,
|
||||
)
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
|
||||
|
||||
def test_accelerator_config_from_partial(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
|
||||
@require_accelerate_version_min_0_28
|
||||
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# case - TrainingArguments.gradient_accumulation_steps == 1
|
||||
# - gradient_accumulation_kwargs['num_steps] == 1
|
||||
# results in grad accum set to 1
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
gradient_accumulation_steps=1,
|
||||
accelerator_config={
|
||||
"gradient_accumulation_kwargs": {
|
||||
"num_steps": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
|
||||
|
||||
# case - TrainingArguments.gradient_accumulation_steps > 1
|
||||
# - gradient_accumulation_kwargs['num_steps] specified
|
||||
# results in exception raised
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
gradient_accumulation_steps=2,
|
||||
accelerator_config={
|
||||
"gradient_accumulation_kwargs": {
|
||||
"num_steps": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
Loading…
Reference in New Issue
Block a user