fix trainer slow tests related to hyperparam search (#24011)
* fix trainer slow tests

* commit 2
This commit is contained in:
parent 3c3108972a
commit 460b844360
src/transformers/trainer.py

@@ -339,31 +339,7 @@ class Trainer:
         self.hp_name = None
         self.is_in_train = False
 
-        # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-        )
-
-        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
-        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
-        # post accelerator creation setup
-        if self.is_fsdp_enabled:
-            fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
-
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        self.create_accelerator_and_postprocess()
 
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -1343,7 +1319,8 @@ class Trainer:
 
             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
-            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+            self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        if self.hp_search_backend is None or trial is None:
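For context, a hedged, minimal usage sketch of the hyperparameter-search path this hunk touches (not part of the commit; the model name, dataset, and trial count are illustrative, and optuna is assumed to be installed). The sketch runs the plain, non-DeepSpeed path; with a DeepSpeed config, each trial would additionally rebuild args.deepspeed_plugin and re-run create_accelerator_and_postprocess() as shown in the hunk above, so the new accelerator reflects the trial's arguments.

# Hypothetical usage sketch exercising Trainer.hyperparameter_search / _hp_search_setup.
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
texts, labels = ["good movie", "bad movie"] * 8, [1, 0] * 8
encodings = tokenizer(texts, truncation=True, padding=True)
dataset = [
    {"input_ids": ids, "attention_mask": mask, "labels": label}
    for ids, mask, label in zip(encodings["input_ids"], encodings["attention_mask"], labels)
]

def model_init():
    # hyperparameter_search requires model_init so a fresh model is built for every trial
    return AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)

def hp_space(trial):
    # search space for the optuna backend; keys must match TrainingArguments fields
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)}

trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hp_out", num_train_epochs=1, report_to=[]),
    train_dataset=dataset,
    eval_dataset=dataset,
)

# each trial updates trainer.args via _hp_search_setup before training starts
best_run = trainer.hyperparameter_search(
    hp_space=hp_space, backend="optuna", n_trials=2, direction="minimize"
)
print(best_run)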
@@ -3924,3 +3901,30 @@ class Trainer:
         if not self.repo.is_repo_clean():
             self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
             self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # create accelerator object
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
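As a hedged, standalone sketch (assuming accelerate is installed; not part of the diff), the flag derivation performed by the new method can be reproduced outside the Trainer: the Accelerator exposes any DeepSpeed/FSDP plugin on its state, and the two booleans fall out of a getattr check exactly as above.

# Minimal reproduction of the flag derivation in create_accelerator_and_postprocess;
# run without an accelerate/deepspeed launcher config, both flags stay False.
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=1)

is_deepspeed_enabled = getattr(accelerator.state, "deepspeed_plugin", None) is not None
is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None

print(f"deepspeed enabled: {is_deepspeed_enabled}, fsdp enabled: {is_fsdp_enabled}")

Inside the Trainer, the plugin passed to Accelerator comes from args.deepspeed_plugin, which the hunk at line 1319 now rebuilds per hyperparameter-search trial before this method runs.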