fix trainer slow tests related to hyperparam search (#24011)

* fix trainer slow tests

* commit 2
Sourab Mangrulkar, 2023-06-05 17:58:10 +05:30 (committed by GitHub)
parent 3c3108972a
commit 460b844360
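
For context, the slow tests named in the title exercise Trainer.hyperparameter_search, which re-runs training once per trial with overridden TrainingArguments. A minimal sketch of that entry point, assuming optuna is installed; the checkpoint, toy dataset, and search space below are illustrative and not taken from the tests:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

# Illustrative in-memory dataset so the sketch is self-contained.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
texts, labels = ["good", "bad", "fine", "poor"], [1, 0, 1, 0]
encodings = tokenizer(texts, truncation=True, padding=True)

class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return len(labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in encodings.items()}
        item["labels"] = torch.tensor(labels[idx])
        return item

def model_init():
    # hyperparameter_search requires model_init so every trial starts from a fresh model
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

def hp_space(trial):
    # optuna-style search space; each trial writes these values back onto TrainingArguments
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hp_out", num_train_epochs=1, report_to="none"),
    train_dataset=ToyDataset(),
    eval_dataset=ToyDataset(),
)
best_run = trainer.hyperparameter_search(hp_space=hp_space, backend="optuna", n_trials=2)
print(best_run.hyperparameters)

Each trial rewrites fields on trainer.args, which is why the accelerator setup in the diff below is factored into a helper that can be re-run after those overrides.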

@@ -339,31 +339,7 @@ class Trainer:
         self.hp_name = None
         self.is_in_train = False
 
-        # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-        )
-
-        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
-        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
-        # post accelerator creation setup
-        if self.is_fsdp_enabled:
-            fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
-
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        self.create_accelerator_and_postprocess()
 
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -1343,7 +1319,8 @@ class Trainer:
 
             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
-            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+        self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
@@ -3924,3 +3901,30 @@ class Trainer:
         if not self.repo.is_repo_clean():
             self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
             self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # create accelerator object
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
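
The pattern the refactor enables, as a hedged sketch (the helper name and override dict are hypothetical, not library code; `trainer` is the one from the sketch above): per-trial overrides are applied to trainer.args and the Accelerator is then rebuilt from them, instead of patching self.accelerator.state in place.

def apply_trial_overrides(trainer, overrides):
    for name, value in overrides.items():
        setattr(trainer.args, name, value)        # e.g. learning_rate, gradient_accumulation_steps
    trainer.create_accelerator_and_postprocess()  # fresh Accelerator plus deepspeed/fsdp flags from args

apply_trial_overrides(trainer, {"gradient_accumulation_steps": 2})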