Fix the issue where the fsdp config does not work. (#37549)

* Fix the issue where the fsdp config does not work.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Check the fsdp_config type

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the accelerate_fsdp_config test

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix the make style error

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add key check

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Authored by Yuan Wu on 2025-04-28 16:44:51 +08:00, committed by GitHub
parent 816b37010c
commit a41b6d9b5c
2 changed files with 42 additions and 4 deletions

src/transformers/training_args.py

@@ -1872,10 +1872,12 @@ class TrainingArguments:
                 warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with open(self.fsdp_config, encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
-                for k in list(self.fsdp_config.keys()):
-                    if k.startswith("fsdp_"):
-                        v = self.fsdp_config.pop(k)
-                        self.fsdp_config[k[5:]] = v
+
+        if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
+            for k in list(self.fsdp_config.keys()):
+                if k.startswith("fsdp_"):
+                    v = self.fsdp_config.pop(k)
+                    self.fsdp_config[k[5:]] = v
 
         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

tests/fsdp/test_fsdp.py

@@ -154,6 +154,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
             "LOCAL_RANK": "0",
             "WORLD_SIZE": "1",
         }
+        self.accelerate_fsdp_config = {
+            "fsdp_activation_checkpointing": False,
+            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+            "fsdp_backward_prefetch": "BACKWARD_PRE",
+            "fsdp_cpu_ram_efficient_loading": True,
+            "fsdp_forward_prefetch": False,
+            "fsdp_offload_params": False,
+            "fsdp_reshard_after_forward": "FULL_SHARD",
+            "fsdp_state_dict_type": "FULL_STATE_DICT",
+            "fsdp_sync_module_states": True,
+            "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+            "fsdp_use_orig_params": True,
+            "fsdp_version": 1,
+        }
         self.fsdp_config = {
             "backward_prefetch": "BACKWARD_PRE",
@@ -169,6 +183,28 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def tearDown(self):
         super().tearDown()
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    def test_accelerate_fsdp_config(self, sharding_strategy, dtype):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": f"{sharding_strategy} offload auto_wrap",
+            "fsdp_config": self.accelerate_fsdp_config,
+        }
+        kwargs[dtype] = True
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
+            self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
+            self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
+            for k, v in trainer.args.fsdp_config.items():
+                self.assertTrue(k in self.accelerate_fsdp_config)
+                self.assertEqual(v, self.accelerate_fsdp_config[k])
+            self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     def test_fsdp_config(self, sharding_strategy, dtype):
         output_dir = self.get_auto_remove_tmp_dir()
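
Outside the test harness, the behavior the new test exercises can be sketched roughly as follows, assuming transformers and accelerate are installed; the output directory and config values are illustrative, not taken from the PR:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="tmp_fsdp_out",  # placeholder path
        fsdp="full_shard offload auto_wrap",
        fsdp_config={
            # accelerate-style, "fsdp_"-prefixed keys
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
            "fsdp_state_dict_type": "FULL_STATE_DICT",
        },
    )

    # After __post_init__ the "fsdp_" prefixes have been stripped, which is what
    # the assertions in test_accelerate_fsdp_config rely on.
    print(args.fsdp_config["auto_wrap_policy"])  # TRANSFORMER_BASED_WRAP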