Disable delay_optimizer_creation in Trainer to support fsdp2 (#37147)

* github why you do this

* fix

* make fixup

* disable cpu offload test

* fixup

* tmp reworks

* git branch movement

* make fixup

* add require_fsdp_v2_version

* dep issues

* update ruff and fixup
This commit is contained in:
byi8220 2025-04-04 14:11:37 -04:00 committed by GitHub
parent 878562b68d
commit a4e55fcff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 0 deletions

View File

@ -2313,6 +2313,11 @@ class Trainer:
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
# Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
if is_fsdp2:
delay_optimizer_creation = False
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
self.lr_scheduler = None

View File

@ -109,6 +109,15 @@ if is_accelerate_available():
require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)
FSDP2_ACCELERATE_VERSION = "1.6.0"
require_accelerate_fsdp2 = partial(require_accelerate, min_version=FSDP2_ACCELERATE_VERSION)
require_fsdp_v2_version = require_fsdp
if is_accelerate_available(min_version=FSDP2_ACCELERATE_VERSION):
from accelerate.utils.constants import FSDP2_PYTORCH_VERSION
require_fsdp_v2_version = partial(require_fsdp, min_version=FSDP2_PYTORCH_VERSION)
def get_launcher(distributed=False, use_accelerate=False):
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
# - it won't be able to handle that
@ -316,6 +325,73 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
except: # noqa
raise AssertionError("CPU offloading failed with FSDP!")
@require_torch_multi_accelerator
@slow
@require_fsdp
@require_fsdp_v2_version
@require_accelerate_fsdp2
def test_accelerate_fsdp2_integration(self):
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
sharding_strategy = "full_shard"
use_accelerate = True
num_gpus = min(2, backend_device_count(torch_device))
master_port = get_master_port(real_launcher=True)
launcher = f"""accelerate launch
--num_processes {num_gpus}
--main_process_port {master_port}
--use_fsdp
--fsdp_version 2
--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP
--fsdp_state_dict_type SHARDED_STATE_DICT
--fsdp_transformer_layer_cls_to_wrap BertLayer""".split()
args = self.get_base_args(output_dir, 2, 25).split()
script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
logs = self.run_cmd_and_get_logs(use_accelerate, sharding_strategy, launcher, script, args, output_dir)
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-115")
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)
logs_resume = self.run_cmd_and_get_logs(
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
)
for log, log1 in zip(logs, logs_resume):
if "learning_rate" in log:
self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
@require_torch_multi_accelerator
@slow
@require_fsdp
@require_fsdp_v2_version
@require_accelerate_fsdp2
def test_fsdp2_cpu_offloading(self):
# TODO: This file is missing and should be added or the test should be removed
if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
raise unittest.SkipTest("FSDP 2 CPU offloading script not found!")
try:
subprocess.run(
"accelerate launch --fsdp_version 2 utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
shell=True,
check=True,
)
except: # noqa
raise AssertionError("CPU offloading failed with FSDP!")
def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
if not use_accelerate:
fsdp_args = [