Fix FSDP resume Initialization issue (#34032)

* Fix FSDP Initialization for resume training

* Added init_fsdp function to work with dummy values

* Fix FSDP initialization for resuming training

* Added CUDA decorator for tests

* Added torch_gpu decorator to FSDP tests

* Fixup for failing code quality tests
This commit is contained in:
Shikhar Mishra 2024-10-15 17:18:10 +05:30 committed by GitHub
parent 293e6271c6
commit 4de1bdbf63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 0 deletions

View File

@ -273,6 +273,39 @@ def _get_fsdp_ckpt_kwargs():
return {}
def _init_fsdp(model, accelerator, device):
"""
Initialize Fully Sharded Data Parallel (FSDP) for the model.
This function is needed to properly initialize FSDP when resuming from a checkpoint.
It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
See https://github.com/huggingface/transformers/issues/31892 for more details.
Args:
model: The model to initialize with FSDP.
accelerator: The Accelerator object.
device: The device to run the model on.
Returns:
The initialized FSDP model.
"""
model = accelerator.prepare(model)
model.train()
with torch.no_grad():
# Run a forward pass with dummy inputs to initialize FSDP
dummy_input = {
name: torch.ones(
(1, 512),
dtype=torch.long,
device=device,
)
for name in model.forward.__code__.co_varnames
if name != "self"
}
_ = model(**dummy_input)
return model
if TYPE_CHECKING:
import optuna
@ -601,6 +634,10 @@ class Trainer:
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
)
if self.is_fsdp_enabled:
self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
self.optimizer is not None or self.lr_scheduler is not None
):

View File

@ -4914,3 +4914,34 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
param = next(model.parameters())
group = trainer.get_optimizer_group(param)
self.assertIn(param, group["params"])
@require_torch_gpu
@require_torch
@require_accelerate
class TestFSDPInitialization(unittest.TestCase):
def test_fsdp_initialization(self):
config = RegressionModelConfig(a=1, b=1, double_output=False)
model = RegressionPreTrainedModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
fsdp=True,
fsdp_config={"min_num_params": 1},
no_cuda=True,
)
trainer = Trainer(model=model, args=training_args)
# Check for FSDP enabled
self.assertTrue(trainer.is_fsdp_enabled)
# Check if model is wrapped with FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
self.assertTrue(trainer.model, FSDP)
# Running a forward pass to ensure FSDP is initialized
dummy_input = torch.ones((1, 1), dtype=torch.float)
output = trainer.model(dummy_input)
self.assertTrue(output)