update run_glue

thomwolf 2019-07-09 17:00:32 +02:00
parent 3b7cb7bf44
commit 4ce237c880


@@ -69,7 +69,11 @@ def train(args, train_dataset, model):
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
if args.max_steps > 0:
num_train_optimization_steps = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer
param_optimizer = list(model.named_parameters())
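The new branch lets --max_steps cap training directly: when it is positive, the step budget comes from the flag and the epoch count is derived from it, instead of the other way round. A minimal standalone sketch of the arithmetic, with hypothetical numbers:

    # Hypothetical: 1000 batches per epoch, gradients accumulated over 2 batches
    steps_per_epoch = 1000 // 2                                # 500 optimizer steps per epoch
    max_steps, num_train_epochs = 1200, 3.0                    # --max_steps and --num_train_epochs
    if max_steps > 0:
        num_train_optimization_steps = max_steps               # the flag wins: 1200 steps
        num_train_epochs = max_steps // steps_per_epoch + 1    # 1200 // 500 + 1 = 3 epochs
    else:
        num_train_optimization_steps = int(steps_per_epoch * num_train_epochs)  # epoch-based budget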
@@ -91,10 +95,8 @@ def train(args, train_dataset, model):
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
# Train!
logger.info("***** Running training *****")
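The BertAdam construction is only reflowed onto fewer lines; optimizer_grouped_parameters is built from param_optimizer just above this hunk. For context, a sketch of the usual two-group weight-decay split used by these example scripts (the exact no_decay list is an assumption, since those lines are not part of this diff):

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']  # assumed no-decay parameter names
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]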
@@ -113,7 +115,7 @@ def train(args, train_dataset, model):
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM doesn't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
loss = outputs[0]
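Packing the forward arguments into a dict and unpacking with **inputs keeps the call identical for every model type; only the dict contents change. A small sketch of the pattern with a hypothetical helper (build_inputs is not part of the script):

    def build_inputs(batch, model_type):
        # BERT and XLNet expect segment ids; XLM does not, so pass None there
        return {'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2] if model_type in ['bert', 'xlnet'] else None,
                'labels': batch[3]}

    outputs = model(**build_inputs(batch, args.model_type))
    loss = outputs[0]  # with labels supplied, the first element of the output tuple is the loss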
@@ -140,14 +142,16 @@ def train(args, train_dataset, model):
if not args.fp16:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step)
if args.max_steps > 0 and global_step > args.max_steps:
break
if args.max_steps > 0 and global_step > args.max_steps:
break
return global_step, tr_loss / global_step
def evalutate(args, eval_task, eval_output_dir, dataset, model):
""" Evaluate the model """
if os.path.exists(eval_output_dir) and os.listdir(eval_output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(eval_output_dir))
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
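The max_steps test appears twice on purpose: a single break only leaves the inner batch loop, so the same condition one level up is needed to stop the epoch loop as well. Schematically, reusing the hypothetical numbers from the earlier sketch:

    max_steps, steps_per_epoch, num_train_epochs = 1200, 500, 3
    global_step = 0
    for _ in range(num_train_epochs):                      # outer loop over epochs
        for _ in range(steps_per_epoch):                   # inner loop over batches
            global_step += 1
            if max_steps > 0 and global_step > max_steps:
                break                                      # exits the batch loop only
        if max_steps > 0 and global_step > max_steps:
            break                                          # exits the epoch loop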
@@ -166,13 +170,13 @@ def evalutate(args, eval_task, eval_output_dir, dataset, model):
out_label_ids = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
batch = tuple(t.to(args.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
with torch.no_grad():
outputs = model(input_ids,
token_type_ids=segment_ids,
attention_mask=input_mask,
labels=label_ids)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM doesn't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
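Under torch.no_grad() the loss can come back with one value per GPU, so .mean() collapses it before .item() turns it into a Python float for accumulation. A standalone sketch of that accumulation (the nb_eval_steps counter is an assumption about lines outside this hunk):

    import torch

    eval_loss, nb_eval_steps = 0.0, 0
    for tmp_eval_loss in [torch.tensor([0.7, 0.9]), torch.tensor([0.5, 0.3])]:  # stand-in per-GPU losses
        eval_loss += tmp_eval_loss.mean().item()  # average across GPUs, accumulate as a float
        nb_eval_steps += 1
    eval_loss /= nb_eval_steps                    # (0.8 + 0.4) / 2 = 0.6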
@@ -276,6 +280,8 @@ def main():
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
parser.add_argument("--no_cuda", action='store_true',
@@ -299,6 +305,9 @@ def main():
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(args.output_dir))
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
@@ -320,8 +329,8 @@ def main():
# Setup logging
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Setup seeds
random.seed(args.seed)
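Replacing logger.info with a .format string by logger.warning with %-style arguments does two things: the warning level keeps the line visible on non-zero ranks (whose level is set to WARN just above), and the arguments are only interpolated if the record is actually emitted. A small self-contained example of the deferred style:

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)
    # Interpolation happens inside the logging framework, only for records that pass the level filter
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   -1, "cpu", 0, False, False)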
@@ -375,8 +384,6 @@ def main():
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
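Guarding os.makedirs with local_rank in [-1, 0] means only the main process (or a non-distributed run, rank -1) creates the output directory, so multiple distributed workers do not race on the same path. A minimal sketch of the guard with a hypothetical path:

    import os

    local_rank = -1                        # -1: no distributed training; 0: main process
    output_dir = "/tmp/glue_output"        # hypothetical output path
    if local_rank in [-1, 0] and not os.path.exists(output_dir):
        os.makedirs(output_dir)            # other ranks skip this and reuse the directory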