mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix FSDP resume Initialization issue (#34032)
* Fix FSDP Initialization for resume training * Added init_fsdp function to work with dummy values * Fix FSDP initialization for resuming training * Added CUDA decorator for tests * Added torch_gpu decorator to FSDP tests * Fixup for failing code quality tests
This commit is contained in:
parent
293e6271c6
commit
4de1bdbf63
@ -273,6 +273,39 @@ def _get_fsdp_ckpt_kwargs():
|
||||
return {}
|
||||
|
||||
|
||||
def _init_fsdp(model, accelerator, device):
|
||||
"""
|
||||
Initialize Fully Sharded Data Parallel (FSDP) for the model.
|
||||
|
||||
This function is needed to properly initialize FSDP when resuming from a checkpoint.
|
||||
It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
|
||||
See https://github.com/huggingface/transformers/issues/31892 for more details.
|
||||
|
||||
Args:
|
||||
model: The model to initialize with FSDP.
|
||||
accelerator: The Accelerator object.
|
||||
device: The device to run the model on.
|
||||
|
||||
Returns:
|
||||
The initialized FSDP model.
|
||||
"""
|
||||
model = accelerator.prepare(model)
|
||||
model.train()
|
||||
with torch.no_grad():
|
||||
# Run a forward pass with dummy inputs to initialize FSDP
|
||||
dummy_input = {
|
||||
name: torch.ones(
|
||||
(1, 512),
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
for name in model.forward.__code__.co_varnames
|
||||
if name != "self"
|
||||
}
|
||||
_ = model(**dummy_input)
|
||||
return model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import optuna
|
||||
|
||||
@ -601,6 +634,10 @@ class Trainer:
|
||||
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
||||
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
||||
)
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
|
||||
|
||||
if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
|
||||
self.optimizer is not None or self.lr_scheduler is not None
|
||||
):
|
||||
|
@ -4914,3 +4914,34 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
|
||||
param = next(model.parameters())
|
||||
group = trainer.get_optimizer_group(param)
|
||||
self.assertIn(param, group["params"])
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
class TestFSDPInitialization(unittest.TestCase):
|
||||
def test_fsdp_initialization(self):
|
||||
config = RegressionModelConfig(a=1, b=1, double_output=False)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
fsdp=True,
|
||||
fsdp_config={"min_num_params": 1},
|
||||
no_cuda=True,
|
||||
)
|
||||
trainer = Trainer(model=model, args=training_args)
|
||||
|
||||
# Check for FSDP enabled
|
||||
self.assertTrue(trainer.is_fsdp_enabled)
|
||||
|
||||
# Check if model is wrapped with FSDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
self.assertTrue(trainer.model, FSDP)
|
||||
|
||||
# Running a forward pass to ensure FSDP is initialized
|
||||
dummy_input = torch.ones((1, 1), dtype=torch.float)
|
||||
output = trainer.model(dummy_input)
|
||||
self.assertTrue(output)
|
||||
|
Loading…
Reference in New Issue
Block a user