mirror of https://github.com/huggingface/transformers.git
Fixup DeepSpeed things (#34007)
parent 17806d11ba
commit f4b741d674
@@ -229,6 +229,7 @@ if is_peft_available():
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
+    from accelerate.state import AcceleratorState
     from accelerate.utils import (
         DistributedDataParallelKwargs,
         DistributedType,
@@ -1676,6 +1677,10 @@ class Trainer:
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
             self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+
+            # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
+            # Simply calling `_reset_state` is enough and doesn't need a version pin.
+            AcceleratorState()._reset_state()

         self.create_accelerator_and_postprocess()

     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
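A minimal sketch of the pattern the added lines rely on, assuming `accelerate` (and DeepSpeed) are installed; the helper name and the config dict below are illustrative, not part of the Trainer. `AcceleratorState` shares its state across every `Accelerator` in the process, so wiping it before building a fresh `DeepSpeedPlugin` keeps one hyperparameter-sweep trial's DeepSpeed settings from leaking into the next.

from accelerate.state import AcceleratorState
from accelerate.utils import DeepSpeedPlugin


def rebuild_deepspeed_plugin_for_trial(ds_config: dict) -> DeepSpeedPlugin:
    """Illustrative helper (not Trainer code): start a sweep trial from a clean slate."""
    # Clear the process-wide state shared by all Accelerator instances so the
    # plugin built below is used instead of the one cached from the previous trial.
    AcceleratorState._reset_state()
    # Build the plugin from this trial's full DeepSpeed config dict
    # (a complete config with the usual sections, e.g. a ZeRO block, is assumed).
    return DeepSpeedPlugin(hf_ds_config=ds_config)

In the diff above, the reset happens right before `create_accelerator_and_postprocess()`, so the Accelerator built for the new trial sees the freshly constructed plugin rather than state left over from the previous run.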