Silence deprecations and use the DataLoaderConfig (#29779)

* Remove deprecations

* Clean
Zach Mueller 2024-03-21 10:26:51 -04:00 committed by GitHub
parent de627f5a14
commit f0bfb150fe

@@ -221,6 +221,9 @@ if is_accelerate_available():
     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
 
+if is_accelerate_available("0.28.0"):
+    from accelerate.utils import DataLoaderConfiguration
+
 
 def _is_peft_model(model):
     if is_peft_available():
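
The new import above is version-gated because DataLoaderConfiguration only exists in sufficiently recent Accelerate releases. Below is a standalone sketch of the same gating pattern, not the transformers implementation; the helper name accelerate_at_least and the use of importlib.metadata/packaging are illustrative assumptions.

import importlib.metadata

from packaging import version


def accelerate_at_least(min_version: str) -> bool:
    # True only when accelerate is installed and at least `min_version`.
    try:
        installed = importlib.metadata.version("accelerate")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(min_version)


if accelerate_at_least("0.28.0"):
    # Only import DataLoaderConfiguration when the installed accelerate provides it.
    from accelerate.utils import DataLoaderConfiguration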
@@ -4248,12 +4251,26 @@ class Trainer:
         grad_acc_kwargs["sync_with_dataloader"] = False
         gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
 
+        accelerator_config = self.args.accelerator_config.to_dict()
+
+        if is_accelerate_available("0.28.0"):
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=accelerator_config.pop("split_batches"),
+                dispatch_batches=accelerator_config.pop("dispatch_batches"),
+                even_batches=accelerator_config.pop("even_batches"),
+                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+            )
+        args = {
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+            "gradient_accumulation_plugin": gradient_accumulation_plugin,
+        }
+        if is_accelerate_available("0.28.0"):
+            args["dataloader_config"] = dataloader_config
+        else:
+            args.update(accelerator_config)
+
         # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_plugin=gradient_accumulation_plugin,
-            **self.args.accelerator_config.to_dict(),
-        )
+        self.accelerator = Accelerator(**args)
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
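
For context, a minimal sketch of the two calling conventions the hunk above switches between; the option values here are illustrative, not the Trainer's own settings. With Accelerate 0.28.0 or newer, the dataloader-related options are grouped into a DataLoaderConfiguration and passed as dataloader_config, whereas passing them as individual keyword arguments to Accelerator is the older path that triggers the deprecation warnings this change silences.

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# New style (accelerate >= 0.28.0): group the dataloader options into one object.
dataloader_config = DataLoaderConfiguration(
    split_batches=False,
    dispatch_batches=None,
    even_batches=True,
    use_seedable_sampler=True,
)
accelerator = Accelerator(dataloader_config=dataloader_config)

# Old style (individual keyword arguments); on newer accelerate this emits
# deprecation warnings, which is what the change above avoids:
# accelerator = Accelerator(split_batches=False, even_batches=True)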