Merge pull request #797 from yzy5630/fix-examples

fix some errors for distributed lm_finetuning
This commit is contained in:
Thomas Wolf 2019-07-18 23:32:33 +02:00 committed by GitHub
commit a615499076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 54 deletions

View File

@ -155,11 +155,14 @@ def main():
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
parser.add_argument("--warmup_proportion", parser.add_argument("--warmup_steps",
default=0.1, default=0,
type=int,
help="Linear warmup over warmup_steps.")
parser.add_argument("--adam_epsilon",
default=1e-8,
type=float, type=float,
help="Proportion of training to perform linear learning rate warmup for. " help="Epsilon for Adam optimizer.")
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--learning_rate", parser.add_argument("--learning_rate",
default=3e-5, default=3e-5,
type=float, type=float,
@ -270,13 +273,9 @@ def main():
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else: else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
else: else:
optimizer = AdamW(optimizer_grouped_parameters, optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
lr=args.learning_rate, scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
logging.info("***** Running training *****") logging.info("***** Running training *****")
@ -298,7 +297,8 @@ def main():
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
loss = outputs[0]
if n_gpu > 1: if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu. loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
@ -314,26 +314,16 @@ def main():
mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}") pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: scheduler.step() # Update learning rate schedule
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model # Save a trained model
logging.info("** ** * Saving fine-tuned model ** ** * ") if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 :
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self logging.info("** ** * Saving fine-tuned model ** ** * ")
model.save_pretrained(args.output_dir)
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) tokenizer.save_pretrained(args.output_dir)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -32,7 +32,7 @@ from tqdm import tqdm, trange
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForPreTraining from pytorch_transformers.modeling_bert import BertForPreTraining
from pytorch_transformers.tokenization_bert import BertTokenizer from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', datefmt='%m/%d/%Y %H:%M:%S',
@ -434,15 +434,18 @@ def main():
default=3e-5, default=3e-5,
type=float, type=float,
help="The initial learning rate for Adam.") help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon",
default=1e-8,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--num_train_epochs", parser.add_argument("--num_train_epochs",
default=3.0, default=3.0,
type=float, type=float,
help="Total number of training epochs to perform.") help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", parser.add_argument("--warmup_steps",
default=0.1, default=0,
type=float, type=int,
help="Proportion of training to perform linear learning rate warmup for. " help="Linear warmup over warmup_steps.")
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda", parser.add_argument("--no_cuda",
action='store_true', action='store_true',
help="Whether not to use CUDA when available") help="Whether not to use CUDA when available")
@ -504,7 +507,7 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
@ -558,14 +561,10 @@ def main():
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else: else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
else: else:
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
lr=args.learning_rate, scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
@ -589,7 +588,8 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
loss = outputs[0]
if n_gpu > 1: if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu. loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
@ -602,25 +602,16 @@ def main():
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1 nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: scheduler.step() # Update learning rate schedule
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model # Save a trained model
logger.info("** ** * Saving fine - tuned model ** ** * ") if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1):
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self logger.info("** ** * Saving fine - tuned model ** ** * ")
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) model.save_pretrained(args.output_dir)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME) tokenizer.save_pretrained(args.output_dir)
if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):