diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 65cb2705cba..3b1b8c58b5d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1872,10 +1872,12 @@ class TrainingArguments: warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") with open(self.fsdp_config, encoding="utf-8") as f: self.fsdp_config = json.load(f) - for k in list(self.fsdp_config.keys()): - if k.startswith("fsdp_"): - v = self.fsdp_config.pop(k) - self.fsdp_config[k[5:]] = v + + if self.fsdp_config is not None and isinstance(self.fsdp_config, dict): + for k in list(self.fsdp_config.keys()): + if k.startswith("fsdp_"): + v = self.fsdp_config.pop(k) + self.fsdp_config[k[5:]] = v if self.fsdp_min_num_params > 0: warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 96c1860860f..781199747f3 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -154,6 +154,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): "LOCAL_RANK": "0", "WORLD_SIZE": "1", } + self.accelerate_fsdp_config = { + "fsdp_activation_checkpointing": False, + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_backward_prefetch": "BACKWARD_PRE", + "fsdp_cpu_ram_efficient_loading": True, + "fsdp_forward_prefetch": False, + "fsdp_offload_params": False, + "fsdp_reshard_after_forward": "FULL_SHARD", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_sync_module_states": True, + "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "fsdp_use_orig_params": True, + "fsdp_version": 1, + } self.fsdp_config = { "backward_prefetch": "BACKWARD_PRE", @@ -169,6 +183,28 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): def tearDown(self): super().tearDown() + @parameterized.expand(params, name_func=_parameterized_custom_name_func) + def test_accelerate_fsdp_config(self, sharding_strategy, dtype): + output_dir = self.get_auto_remove_tmp_dir() + kwargs = { + "output_dir": output_dir, + "train_len": 128, + "save_steps": 5, + "learning_rate": 0.1, + "fsdp": f"{sharding_strategy} offload auto_wrap", + "fsdp_config": self.accelerate_fsdp_config, + } + kwargs[dtype] = True + with mockenv_context(**self.dist_env_1_gpu): + trainer = get_regression_trainer(**kwargs) + self.assertEqual(trainer.args.fsdp[0], sharding_strategy) + self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD) + self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP) + for k, v in trainer.args.fsdp_config.items(): + self.assertTrue(k in self.accelerate_fsdp_config) + self.assertEqual(v, self.accelerate_fsdp_config[k]) + self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true") + @parameterized.expand(params, name_func=_parameterized_custom_name_func) def test_fsdp_config(self, sharding_strategy, dtype): output_dir = self.get_auto_remove_tmp_dir()