Fix the fsdp config cannot work issue. (#37549)
* Fix the fsdp config cannot work issue.
* Check the fsdp_config type
* Add the accelerate_fsdp_config test
* fix error of make style
* Add key check

Signed-off-by: yuanwu <yuan.wu@intel.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 816b37010c
commit a41b6d9b5c
@@ -1872,6 +1872,8 @@ class TrainingArguments:
                 warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with open(self.fsdp_config, encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
+
+        if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
             for k in list(self.fsdp_config.keys()):
                 if k.startswith("fsdp_"):
                     v = self.fsdp_config.pop(k)
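The hunk above touches the fsdp_config handling in TrainingArguments (in the transformers tree this lives in src/transformers/training_args.py): the `fsdp_` prefix stripping now also runs when `fsdp_config` is passed as an in-memory dict, not only when it is loaded from a JSON path. A minimal standalone sketch of that normalization follows; the sample dict and the reinsertion line (`k[len("fsdp_"):]`, mirroring the `k[5:]` assignment that comes right after the hunk's last context line) are illustrative assumptions, not part of the commit:

# Standalone sketch of the key normalization enabled by the new branch above.
# "fsdp_"-prefixed keys (accelerate style) are re-inserted without the prefix,
# so dict configs and JSON-file configs end up with the same key names.
fsdp_config = {"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "min_num_params": 0}

if fsdp_config is not None and isinstance(fsdp_config, dict):
    for k in list(fsdp_config.keys()):
        if k.startswith("fsdp_"):
            v = fsdp_config.pop(k)
            fsdp_config[k[len("fsdp_"):]] = v  # assumed reinsertion of the stripped key

print(fsdp_config)  # {'min_num_params': 0, 'auto_wrap_policy': 'TRANSFORMER_BASED_WRAP'}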
@@ -154,6 +154,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
             "LOCAL_RANK": "0",
             "WORLD_SIZE": "1",
         }
+        self.accelerate_fsdp_config = {
+            "fsdp_activation_checkpointing": False,
+            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+            "fsdp_backward_prefetch": "BACKWARD_PRE",
+            "fsdp_cpu_ram_efficient_loading": True,
+            "fsdp_forward_prefetch": False,
+            "fsdp_offload_params": False,
+            "fsdp_reshard_after_forward": "FULL_SHARD",
+            "fsdp_state_dict_type": "FULL_STATE_DICT",
+            "fsdp_sync_module_states": True,
+            "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+            "fsdp_use_orig_params": True,
+            "fsdp_version": 1,
+        }
 
         self.fsdp_config = {
             "backward_prefetch": "BACKWARD_PRE",
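The new accelerate_fsdp_config fixture uses the `fsdp_`-prefixed key names that `accelerate config` writes out. For comparison, the JSON-file route that already worked before this commit takes the same keys; a minimal sketch, where the tempfile-based file is purely illustrative:

import json
import tempfile

# Write a few accelerate-style options to a JSON file; passing this path as
# `fsdp_config` goes through the json.load branch shown in the first hunk.
cfg = {
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
    "fsdp_backward_prefetch": "BACKWARD_PRE",
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(cfg, f)
    fsdp_config_path = f.name  # pass as fsdp_config=... or --fsdp_config on the CLI

print(fsdp_config_path)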
@@ -169,6 +183,28 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def tearDown(self):
         super().tearDown()
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    def test_accelerate_fsdp_config(self, sharding_strategy, dtype):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": f"{sharding_strategy} offload auto_wrap",
+            "fsdp_config": self.accelerate_fsdp_config,
+        }
+        kwargs[dtype] = True
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
+            self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
+            self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
+            for k, v in trainer.args.fsdp_config.items():
+                self.assertTrue(k in self.accelerate_fsdp_config)
+                self.assertEqual(v, self.accelerate_fsdp_config[k])
+            self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     def test_fsdp_config(self, sharding_strategy, dtype):
         output_dir = self.get_auto_remove_tmp_dir()
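In the new test, `kwargs[dtype] = True` turns the parameterized dtype name into a boolean trainer argument, and the `fsdp` string combines the sharding strategy with the offload and auto_wrap options. A minimal sketch of that kwargs construction; the concrete values "full_shard" and "bf16" are assumptions, the real ones come from the `params` list expanded by @parameterized.expand:

# Sketch of how one parameterized case builds its trainer kwargs.
sharding_strategy, dtype = "full_shard", "bf16"  # assumed values for illustration

kwargs = {
    "output_dir": "/tmp/fsdp_demo",                    # placeholder path
    "fsdp": f"{sharding_strategy} offload auto_wrap",  # same string the test builds
}
kwargs[dtype] = True  # expands to bf16=True (or fp16=True) when forwarded to the trainer

print(kwargs)

The assertions then check that these options survive TrainingArguments post-processing (the sharding strategy plus OFFLOAD and AUTO_WRAP, and the fsdp_config contents) and that ACCELERATE_USE_FSDP is exported as "true".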