Fix callback handler reference (#36250)

* fix reference

* style
This commit is contained in:
Marc Sun 2025-02-19 18:17:33 +01:00 committed by GitHub
parent 78d6484675
commit 31bb662db1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 6 deletions

View File

@@ -2445,7 +2445,11 @@ class Trainer:
)
# Update the references
self.state.init_training_references(self, train_dataloader, max_steps, num_train_epochs, trial)
for attr in ("model", "optimizer", "lr_scheduler"):
setattr(self.callback_handler, attr, getattr(self, attr))
self.callback_handler.train_dataloader = train_dataloader
self.state.init_training_references(self, max_steps, num_train_epochs, trial)
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)

View File

@@ -164,14 +164,10 @@ class TrainerState:
num_steps = math.ceil(max_steps * num_steps)
setattr(self, f"{step_kind}_steps", num_steps)
def init_training_references(self, trainer, train_dataloader, max_steps, num_train_epochs, trial):
def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
"""
Stores the initial training references needed in `self`
"""
for attr in ("model", "optimizer", "lr_scheduler"):
setattr(self, attr, getattr(trainer, attr))
self.train_dataloader = train_dataloader
if trainer.hp_name is not None and trainer._trial is not None:
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
# parameter to Train when using DDP.