mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Fix the fsdp config cannot work issue. (#37549)
* Fix the fsdp config cannot work issue. Signed-off-by: yuanwu <yuan.wu@intel.com> * Check the fsdp_config type Signed-off-by: yuanwu <yuan.wu@intel.com> * Add the accelerate_fsdp_config test Signed-off-by: yuanwu <yuan.wu@intel.com> * fix error of make style Signed-off-by: yuanwu <yuan.wu@intel.com> * Add key check Signed-off-by: yuanwu <yuan.wu@intel.com> --------- Signed-off-by: yuanwu <yuan.wu@intel.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
816b37010c
commit
a41b6d9b5c
@ -1872,10 +1872,12 @@ class TrainingArguments:
|
||||
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
|
||||
with open(self.fsdp_config, encoding="utf-8") as f:
|
||||
self.fsdp_config = json.load(f)
|
||||
for k in list(self.fsdp_config.keys()):
|
||||
if k.startswith("fsdp_"):
|
||||
v = self.fsdp_config.pop(k)
|
||||
self.fsdp_config[k[5:]] = v
|
||||
|
||||
if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
|
||||
for k in list(self.fsdp_config.keys()):
|
||||
if k.startswith("fsdp_"):
|
||||
v = self.fsdp_config.pop(k)
|
||||
self.fsdp_config[k[5:]] = v
|
||||
|
||||
if self.fsdp_min_num_params > 0:
|
||||
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
|
||||
|
@ -154,6 +154,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
"LOCAL_RANK": "0",
|
||||
"WORLD_SIZE": "1",
|
||||
}
|
||||
self.accelerate_fsdp_config = {
|
||||
"fsdp_activation_checkpointing": False,
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"fsdp_backward_prefetch": "BACKWARD_PRE",
|
||||
"fsdp_cpu_ram_efficient_loading": True,
|
||||
"fsdp_forward_prefetch": False,
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_reshard_after_forward": "FULL_SHARD",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_use_orig_params": True,
|
||||
"fsdp_version": 1,
|
||||
}
|
||||
|
||||
self.fsdp_config = {
|
||||
"backward_prefetch": "BACKWARD_PRE",
|
||||
@ -169,6 +183,28 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
||||
def test_accelerate_fsdp_config(self, sharding_strategy, dtype):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
kwargs = {
|
||||
"output_dir": output_dir,
|
||||
"train_len": 128,
|
||||
"save_steps": 5,
|
||||
"learning_rate": 0.1,
|
||||
"fsdp": f"{sharding_strategy} offload auto_wrap",
|
||||
"fsdp_config": self.accelerate_fsdp_config,
|
||||
}
|
||||
kwargs[dtype] = True
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
|
||||
self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
|
||||
self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
|
||||
for k, v in trainer.args.fsdp_config.items():
|
||||
self.assertTrue(k in self.accelerate_fsdp_config)
|
||||
self.assertEqual(v, self.accelerate_fsdp_config[k])
|
||||
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
||||
|
||||
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
||||
def test_fsdp_config(self, sharding_strategy, dtype):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
Loading…
Reference in New Issue
Block a user