From a41b6d9b5c39e002670c1d852e822d0de4aca39b Mon Sep 17 00:00:00 2001
From: Yuan Wu
Date: Mon, 28 Apr 2025 16:44:51 +0800
Subject: [PATCH] Fix the fsdp config cannot work issue. (#37549)

* Fix the fsdp config cannot work issue.

Signed-off-by: yuanwu

* Check the fsdp_config type

Signed-off-by: yuanwu

* Add the accelerate_fsdp_config test

Signed-off-by: yuanwu

* fix error of make style

Signed-off-by: yuanwu

* Add key check

Signed-off-by: yuanwu

---------

Signed-off-by: yuanwu
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/training_args.py | 10 ++++++----
 tests/fsdp/test_fsdp.py           | 36 +++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 65cb2705cba..3b1b8c58b5d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1872,10 +1872,12 @@ class TrainingArguments:
                 warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with open(self.fsdp_config, encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
-            for k in list(self.fsdp_config.keys()):
-                if k.startswith("fsdp_"):
-                    v = self.fsdp_config.pop(k)
-                    self.fsdp_config[k[5:]] = v
+
+        if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
+            for k in list(self.fsdp_config.keys()):
+                if k.startswith("fsdp_"):
+                    v = self.fsdp_config.pop(k)
+                    self.fsdp_config[k[5:]] = v
 
         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 96c1860860f..781199747f3 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -154,6 +154,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
             "LOCAL_RANK": "0",
             "WORLD_SIZE": "1",
         }
+        self.accelerate_fsdp_config = {
+            "fsdp_activation_checkpointing": False,
+            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+            "fsdp_backward_prefetch": "BACKWARD_PRE",
+            "fsdp_cpu_ram_efficient_loading": True,
+            "fsdp_forward_prefetch": False,
+            "fsdp_offload_params": False,
+            "fsdp_reshard_after_forward": "FULL_SHARD",
+            "fsdp_state_dict_type": "FULL_STATE_DICT",
+            "fsdp_sync_module_states": True,
+            "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+            "fsdp_use_orig_params": True,
+            "fsdp_version": 1,
+        }
 
         self.fsdp_config = {
             "backward_prefetch": "BACKWARD_PRE",
@@ -169,6 +183,28 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def tearDown(self):
         super().tearDown()
 
+    @parameterized.expand(params, name_func=_parameterized_custom_name_func)
+    def test_accelerate_fsdp_config(self, sharding_strategy, dtype):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": f"{sharding_strategy} offload auto_wrap",
+            "fsdp_config": self.accelerate_fsdp_config,
+        }
+        kwargs[dtype] = True
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
+            self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
+            self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
+            for k, v in trainer.args.fsdp_config.items():
+                self.assertTrue(k in self.accelerate_fsdp_config)
+                self.assertEqual(v, self.accelerate_fsdp_config[k])
+            self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     def test_fsdp_config(self, sharding_strategy, dtype):
         output_dir = self.get_auto_remove_tmp_dir()
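
For reference, a minimal usage sketch of the behavior this patch enables: passing an accelerate-style `fsdp_config` dict (keys carrying the `fsdp_` prefix) directly to `TrainingArguments` instead of a JSON file path. The guarded loop added above strips the prefix, so the resulting `args.fsdp_config` keys come out unprefixed. The output directory and config values below are illustrative only and are not taken from the patch.

```python
# Minimal sketch (not part of the patch): pass an accelerate-style fsdp_config
# dict directly to TrainingArguments. Assumes transformers with this fix and
# accelerate installed; paths and values here are illustrative only.
from transformers import TrainingArguments

accelerate_style_config = {
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_backward_prefetch": "BACKWARD_PRE",
    "fsdp_state_dict_type": "FULL_STATE_DICT",
    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
}

args = TrainingArguments(
    output_dir="/tmp/fsdp_demo",
    fsdp="full_shard offload auto_wrap",
    fsdp_config=accelerate_style_config,  # dict is accepted; no JSON file needed
)

# After __post_init__, the "fsdp_" prefixes have been stripped,
# e.g. "auto_wrap_policy" instead of "fsdp_auto_wrap_policy".
print(sorted(args.fsdp_config))
```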