Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Parent: 78d6484675
Commit: 31bb662db1
@@ -2445,7 +2445,11 @@ class Trainer:
         )
 
         # Update the references
-        self.state.init_training_references(self, train_dataloader, max_steps, num_train_epochs, trial)
+        for attr in ("model", "optimizer", "lr_scheduler"):
+            setattr(self.callback_handler, attr, getattr(self, attr))
+        self.callback_handler.train_dataloader = train_dataloader
+
+        self.state.init_training_references(self, max_steps, num_train_epochs, trial)
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(args.device)
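For illustration only, here is a minimal, self-contained sketch of the wiring pattern this hunk introduces, using hypothetical stand-in classes (MiniTrainer, CallbackHandler, State) rather than the real Trainer / TrainerState / CallbackHandler API: the callback handler is handed the live model, optimizer, lr_scheduler and train dataloader directly from the trainer, while the state is now called without the dataloader and keeps only bookkeeping values.

class CallbackHandler:
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.lr_scheduler = None
        self.train_dataloader = None

class State:
    def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
        # Only plain bookkeeping values are stored here, no live object references.
        self.max_steps = max_steps
        self.num_train_epochs = num_train_epochs
        self.trial = trial

class MiniTrainer:
    def __init__(self, model, optimizer, lr_scheduler):
        self.model, self.optimizer, self.lr_scheduler = model, optimizer, lr_scheduler
        self.callback_handler = CallbackHandler()
        self.state = State()

    def update_references(self, train_dataloader, max_steps, num_train_epochs, trial=None):
        # Same pattern as the added lines in the hunk above: copy the live
        # references onto the callback handler, then initialize the state.
        for attr in ("model", "optimizer", "lr_scheduler"):
            setattr(self.callback_handler, attr, getattr(self, attr))
        self.callback_handler.train_dataloader = train_dataloader
        self.state.init_training_references(self, max_steps, num_train_epochs, trial)

trainer = MiniTrainer(model="model", optimizer="optimizer", lr_scheduler="scheduler")
trainer.update_references(train_dataloader=["batch_0", "batch_1"], max_steps=10, num_train_epochs=1)
print(trainer.callback_handler.model, trainer.state.max_steps)  # -> model 10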
@@ -164,14 +164,10 @@ class TrainerState:
             num_steps = math.ceil(max_steps * num_steps)
             setattr(self, f"{step_kind}_steps", num_steps)
 
-    def init_training_references(self, trainer, train_dataloader, max_steps, num_train_epochs, trial):
+    def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
         """
         Stores the initial training references needed in `self`
         """
-        for attr in ("model", "optimizer", "lr_scheduler"):
-            setattr(self, attr, getattr(trainer, attr))
-
-        self.train_dataloader = train_dataloader
         if trainer.hp_name is not None and trainer._trial is not None:
             # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
             # parameter to Train when using DDP.
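As an aside, the two context lines at the top of this hunk convert a fractional setting into an absolute step count (ceiling of max_steps times the ratio). A standalone sketch of that conversion, using hypothetical names (resolve_steps, settings) not taken from the diff, and assuming values below 1 are meant as a fraction of max_steps:

import math

def resolve_steps(max_steps, settings):
    # settings maps a step kind ("logging", "eval", "save") to either an absolute
    # count (int >= 1) or a fraction of max_steps (float in (0, 1)).
    resolved = {}
    for step_kind, num_steps in settings.items():
        if isinstance(num_steps, float) and 0 < num_steps < 1:
            num_steps = math.ceil(max_steps * num_steps)  # same conversion as in the hunk
        resolved[f"{step_kind}_steps"] = num_steps
    return resolved

print(resolve_steps(1000, {"logging": 0.1, "eval": 0.25, "save": 500}))
# -> {'logging_steps': 100, 'eval_steps': 250, 'save_steps': 500}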