mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Disable delay_optimizer_creation in Trainer to support fsdp2 (#37147)
* github why you do this
* fix
* make fixup
* disable cpu offload test
* fixup
* tmp reworks
* git branch movement
* make fixup
* add require_fsdp_v2_version
* dep issues
* update ruff and fixup
This commit is contained in:
parent 878562b68d
commit a4e55fcff8
@@ -2313,6 +2313,11 @@ class Trainer:
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

        # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
        if is_fsdp2:
            delay_optimizer_creation = False

        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
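This hunk is the heart of the fix: with FSDP2, accelerate has to prepare the model and optimizer together, so the Trainer can no longer defer optimizer creation until after the model is wrapped. A minimal sketch of the same probe outside the Trainer, assuming accelerate >= 1.6.0 (older releases have no fsdp_version field on the plugin, hence the getattr fallback to 1):

from accelerate import Accelerator

# Picks up the FSDP configuration from the environment / accelerate config.
accelerator = Accelerator()
plugin = getattr(accelerator.state, "fsdp_plugin", None)

# Mirrors the check in the diff: treat a missing fsdp_version as FSDP1.
is_fsdp2 = plugin is not None and getattr(plugin, "fsdp_version", 1) == 2

# FSDP1 lets the optimizer be created after the model is wrapped; FSDP2
# needs it to exist up front, so creation must not be delayed.
delay_optimizer_creation = plugin is not None and not is_fsdp2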
@@ -109,6 +109,15 @@ if is_accelerate_available():
require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)


FSDP2_ACCELERATE_VERSION = "1.6.0"
require_accelerate_fsdp2 = partial(require_accelerate, min_version=FSDP2_ACCELERATE_VERSION)
require_fsdp_v2_version = require_fsdp
if is_accelerate_available(min_version=FSDP2_ACCELERATE_VERSION):
    from accelerate.utils.constants import FSDP2_PYTORCH_VERSION

    require_fsdp_v2_version = partial(require_fsdp, min_version=FSDP2_PYTORCH_VERSION)


def get_launcher(distributed=False, use_accelerate=False):
    # 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
    # - it won't be able to handle that
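These helpers lean on the standard functools.partial gating pattern: pre-binding min_version produces an object that still works as a plain test decorator, and the require_fsdp_v2_version = require_fsdp fallback keeps the FSDP2_PYTORCH_VERSION import behind the accelerate version check, so old installs never touch it. A self-contained sketch of the pattern (the require_fsdp body below is a stand-in, not the real helper from this file, and the "2.6.0" bound is illustrative):

import unittest
from functools import partial

from packaging import version


def require_fsdp(test_case, min_version="1.12.0"):
    # Stand-in: the real helper inspects the installed torch version.
    installed = version.parse("2.5.1")  # pretend this was detected
    if installed < version.parse(min_version):
        return unittest.skip(f"test requires FSDP >= {min_version}")(test_case)
    return test_case


# Same shape as require_fsdp_v2_version above: the bound keyword turns
# the generic helper into a stricter decorator.
require_fsdp_v2 = partial(require_fsdp, min_version="2.6.0")


class TestExample(unittest.TestCase):
    @require_fsdp_v2
    def test_fsdp2_feature(self):
        ...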
@@ -316,6 +325,73 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
        except:  # noqa
            raise AssertionError("CPU offloading failed with FSDP!")

    @require_torch_multi_accelerator
    @slow
    @require_fsdp
    @require_fsdp_v2_version
    @require_accelerate_fsdp2
    def test_accelerate_fsdp2_integration(self):
        output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
        sharding_strategy = "full_shard"
        use_accelerate = True

        num_gpus = min(2, backend_device_count(torch_device))
        master_port = get_master_port(real_launcher=True)
        launcher = f"""accelerate launch
            --num_processes {num_gpus}
            --main_process_port {master_port}
            --use_fsdp
            --fsdp_version 2
            --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP
            --fsdp_state_dict_type SHARDED_STATE_DICT
            --fsdp_transformer_layer_cls_to_wrap BertLayer""".split()
        args = self.get_base_args(output_dir, 2, 25).split()
        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
        logs = self.run_cmd_and_get_logs(use_accelerate, sharding_strategy, launcher, script, args, output_dir)
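For orientation: launcher, script, and args are three argv fragments that get concatenated into one accelerate launch ... run_glue.py ... invocation. The diff does not show the internals of run_cmd_and_get_logs; a plausible sketch of its execution step, using the real execute_subprocess_async helper from transformers.testing_utils:

from transformers.testing_utils import execute_subprocess_async

# launcher: ["accelerate", "launch", "--num_processes", "2", ...]
# script:   [".../run_glue.py"]
# args:     ["--output_dir", ..., "--num_train_epochs", "2", ...]
cmd = launcher + script + args
execute_subprocess_async(cmd)  # raises if the child process exits non-zero

Back in the test, the run is then resumed from a saved checkpoint: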
        # resume from ckpt
        checkpoint = os.path.join(output_dir, "checkpoint-115")
        resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()

        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
            any(
                FSDP_MODEL_NAME in folder_name
                for folder_name in os.listdir(checkpoint)
                if os.path.isdir(os.path.join(checkpoint, folder_name))
            )
            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
        )
        self.assertTrue(is_fsdp_ckpt)
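The two branches of is_fsdp_ckpt cover the two on-disk layouts an FSDP checkpoint can take. A standalone restatement of the check, assuming FSDP_MODEL_NAME == "pytorch_model_fsdp" (its value in transformers):

import os

FSDP_MODEL_NAME = "pytorch_model_fsdp"  # assumed value of the constant


def is_fsdp_checkpoint(checkpoint):
    # SHARDED_STATE_DICT: per-rank shards live under a subdirectory whose
    # name contains FSDP_MODEL_NAME, e.g. checkpoint-115/pytorch_model_fsdp_0/
    # FULL_STATE_DICT: one consolidated file, e.g. pytorch_model_fsdp.bin
    if not os.path.isdir(checkpoint):
        return False
    sharded = any(
        FSDP_MODEL_NAME in name
        for name in os.listdir(checkpoint)
        if os.path.isdir(os.path.join(checkpoint, name))
    )
    full = os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
    return sharded or full

The test then re-runs training from the checkpoint and checks that the learning-rate schedule picks up where it left off: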
        logs_resume = self.run_cmd_and_get_logs(
            use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
        )

        for log, log1 in zip(logs, logs_resume):
            if "learning_rate" in log:
                self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
    @require_torch_multi_accelerator
    @slow
    @require_fsdp
    @require_fsdp_v2_version
    @require_accelerate_fsdp2
    def test_fsdp2_cpu_offloading(self):
        # TODO: This file is missing and should be added or the test should be removed
        if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
            raise unittest.SkipTest("FSDP 2 CPU offloading script not found!")

        try:
            subprocess.run(
                "accelerate launch --fsdp_version 2 utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
                shell=True,
                check=True,
            )
        except:  # noqa
            raise AssertionError("CPU offloading failed with FSDP!")
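A side note on the try/except in test_fsdp2_cpu_offloading: with check=True, a non-zero exit raises subprocess.CalledProcessError, so that is the only exception the bare except really needs to translate into a test failure. A narrower equivalent (a sketch, not what the diff ships):

import subprocess

cmd = (
    "accelerate launch --fsdp_version 2 utils/testing_scripts/fsdp_cpu_offloading.py"
    " --config utils/testing_scripts/dummy_fsdp_config.yml"
)
try:
    subprocess.run(cmd, shell=True, check=True)
except subprocess.CalledProcessError as e:
    raise AssertionError(f"CPU offloading failed with FSDP! (exit code {e.returncode})") from e

The hunk closes on the start of the existing run_cmd_and_get_logs helper: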
    def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
        if not use_accelerate:
            fsdp_args = [