Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 03:01:07 +06:00
Fix FSDP resume Initialization issue (#34032)
* Fix FSDP Initialization for resume training
* Added init_fsdp function to work with dummy values
* Fix FSDP initialization for resuming training
* Added CUDA decorator for tests
* Added torch_gpu decorator to FSDP tests
* Fixup for failing code quality tests
This commit is contained in: parent 293e6271c6, commit 4de1bdbf63
@@ -273,6 +273,39 @@ def _get_fsdp_ckpt_kwargs():
         return {}


+def _init_fsdp(model, accelerator, device):
+    """
+    Initialize Fully Sharded Data Parallel (FSDP) for the model.
+
+    This function is needed to properly initialize FSDP when resuming from a checkpoint.
+    It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
+    See https://github.com/huggingface/transformers/issues/31892 for more details.
+
+    Args:
+        model: The model to initialize with FSDP.
+        accelerator: The Accelerator object.
+        device: The device to run the model on.
+
+    Returns:
+        The initialized FSDP model.
+    """
+    model = accelerator.prepare(model)
+    model.train()
+    with torch.no_grad():
+        # Run a forward pass with dummy inputs to initialize FSDP
+        dummy_input = {
+            name: torch.ones(
+                (1, 512),
+                dtype=torch.long,
+                device=device,
+            )
+            for name in model.forward.__code__.co_varnames
+            if name != "self"
+        }
+        _ = model(**dummy_input)
+    return model
+
+
 if TYPE_CHECKING:
     import optuna
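For context, here is a standalone sketch (not part of the commit) of the dummy-input trick that `_init_fsdp` relies on: argument names are read from `forward.__code__.co_varnames` and each one gets a `(1, 512)` tensor of ones so a single forward pass can run. The toy model and its argument names below are illustrative assumptions, not code from the repository.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for a transformer whose forward takes input_ids and attention_mask."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, input_ids, attention_mask):
        # Kept free of local variables so co_varnames lists only the arguments.
        return self.head((self.embed(input_ids) * attention_mask.unsqueeze(-1)).mean(dim=1))


model = ToyModel()
device = torch.device("cpu")

# Same construction as in the hunk above: one (1, 512) tensor of ones per
# forward argument, skipping "self".
dummy_input = {
    name: torch.ones((1, 512), dtype=torch.long, device=device)
    for name in model.forward.__code__.co_varnames
    if name != "self"
}

with torch.no_grad():
    _ = model(**dummy_input)  # in the real case this forward pass materializes FSDP state
print(sorted(dummy_input))  # ['attention_mask', 'input_ids']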
@@ -601,6 +634,10 @@ class Trainer:
                 " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                 " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
             )

+        if self.is_fsdp_enabled:
+            self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
+
         if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
@@ -4914,3 +4914,34 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
         param = next(model.parameters())
         group = trainer.get_optimizer_group(param)
         self.assertIn(param, group["params"])
+
+
+@require_torch_gpu
+@require_torch
+@require_accelerate
+class TestFSDPInitialization(unittest.TestCase):
+    def test_fsdp_initialization(self):
+        config = RegressionModelConfig(a=1, b=1, double_output=False)
+        model = RegressionPreTrainedModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                fsdp=True,
+                fsdp_config={"min_num_params": 1},
+                no_cuda=True,
+            )
+            trainer = Trainer(model=model, args=training_args)
+
+            # Check for FSDP enabled
+            self.assertTrue(trainer.is_fsdp_enabled)
+
+            # Check if model is wrapped with FSDP
+            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+            self.assertTrue(trainer.model, FSDP)
+
+            # Running a forward pass to ensure FSDP is initialized
+            dummy_input = torch.ones((1, 1), dtype=torch.float)
+            output = trainer.model(dummy_input)
+            self.assertTrue(output)
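A hedged way to exercise the new test locally (not part of the commit): the class is gated by @require_torch_gpu and @require_accelerate, so it needs a CUDA machine with accelerate installed, and the test-file path below is an assumption since the diff header does not show filenames.

# Hypothetical local run of the added test; tests/trainer/test_trainer.py is assumed.
import pytest

exit_code = pytest.main(["-x", "-k", "TestFSDPInitialization", "tests/trainer/test_trainer.py"])
raise SystemExit(exit_code)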