Fixup DeepSpeed things (#34007)

This commit is contained in:
Zach Mueller 2024-10-08 09:04:24 -04:00 committed by GitHub
parent 17806d11ba
commit f4b741d674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -229,6 +229,7 @@ if is_peft_available():
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.state import AcceleratorState
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
@@ -1676,6 +1677,10 @@ class Trainer:
self.args.hf_deepspeed_config.trainer_config_process(self.args)
self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
# From Accelerate 1.0 on, we need to fully wipe the DeepSpeed (DS) plugin when doing sweeps.
# Simply calling `_reset_state` is enough, and the call doesn't need a version pin.
AcceleratorState()._reset_state()
self.create_accelerator_and_postprocess()
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):