Add functionality to continue training from last saved global_step
This commit is contained in: parent 2d73591a18, commit 9626e0458c
@@ -223,17 +223,37 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # set global_step to global_step of last saved checkpoint from model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
 
     model_to_resize = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
     model_to_resize.resize_token_embeddings(len(tokenizer))
 
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for epoch in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
             inputs = inputs.to(args.device)
             labels = labels.to(args.device)
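
For readers who want the resume arithmetic outside the diff context, the sketch below reproduces it as a standalone helper. It is a minimal illustration, assuming checkpoints are saved under directories named checkpoint-<global_step> (e.g. output/checkpoint-500), as the example scripts do; the function name resume_state and the concrete numbers are hypothetical, not part of the library.

# Minimal sketch of the resume logic added in this commit (assumed naming:
# checkpoint directories like "output/checkpoint-500"; `resume_state` is an
# illustrative helper name, not a transformers API).
def resume_state(model_name_or_path, num_batches_per_epoch, gradient_accumulation_steps):
    """Recover (global_step, epochs_trained, steps_trained_in_current_epoch)
    from a checkpoint path such as "output/checkpoint-500"."""
    # The optimizer-step count is the numeric suffix of the checkpoint directory name.
    global_step = int(model_name_or_path.split("-")[-1].split("/")[0])
    # One optimizer step is taken every `gradient_accumulation_steps` batches.
    steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps
    epochs_trained = global_step // steps_per_epoch
    steps_trained_in_current_epoch = global_step % steps_per_epoch
    return global_step, epochs_trained, steps_trained_in_current_epoch

if __name__ == "__main__":
    # Example: 500 optimizer steps already taken, 400 batches per epoch,
    # gradient accumulation of 2 -> 200 optimizer steps per epoch.
    gs, epochs, skip = resume_state("output/checkpoint-500", 400, 2)
    print(gs, epochs, skip)  # 500 2 100 -> resume in epoch 2, skip first 100 steps of it

Deriving the resume point from the directory name keeps the change self-contained: no separate state file is required, only the checkpoint-<global_step> naming convention already used when checkpoints are saved.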