Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Save optimizer state, scheduler state and current epoch
parent 0cb163865a
commit f71b1bb05a
@@ -224,7 +224,7 @@ def train(args, train_dataset, model, tokenizer):
     model.zero_grad()
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
-    for _ in train_iterator:
+    for epoch in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
             inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
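The rename from _ to epoch binds the loop counter so the checkpoint code added in the next hunk can persist it; a sketch of the matching restore step follows the diff.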
@@ -279,6 +279,10 @@ def train(args, train_dataset, model, tokenizer):
 
                     _rotate_checkpoints(args, checkpoint_prefix)
 
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    torch.save(epoch, os.path.join(output_dir, 'training_state.pt'))
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
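The new torch.save calls only write state; the matching restore step lives on the loading side and is not part of this diff. Below is a minimal sketch of how the three files could be read back when resuming training. The helper restore_training_state, its arguments, and the resume-at-the-next-epoch convention are assumptions for illustration, not code from the repository.

import os
import torch

def restore_training_state(checkpoint_dir, optimizer, scheduler):
    # Hypothetical helper, not part of this commit. Assumes a checkpoint
    # directory produced by the torch.save() calls above, plus already
    # constructed optimizer/scheduler objects matching the original run.
    optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
    scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
    state_path = os.path.join(checkpoint_dir, 'training_state.pt')

    if os.path.isfile(optimizer_path):
        # Restores optimizer internals (e.g. Adam moment estimates) so
        # resumed update steps match the interrupted run
        optimizer.load_state_dict(torch.load(optimizer_path))
    if os.path.isfile(scheduler_path):
        # Restores the learning-rate schedule position
        scheduler.load_state_dict(torch.load(scheduler_path))

    start_epoch = 0
    if os.path.isfile(state_path):
        # training_state.pt holds the bare epoch counter saved above,
        # so torch.load returns an int
        start_epoch = torch.load(state_path) + 1  # resume at the next epoch
    return start_epoch

A caller would invoke this after rebuilding the model, optimizer and scheduler, then start the epoch loop at the returned index instead of 0.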