fixing pre-processing bug - averaging loss for gradient accumulation - no_grad on evaluation

This commit is contained in:
thomwolf 2018-11-07 22:12:41 +01:00
parent 1a5bbd83dc
commit 6bb7510a50
2 changed files with 24 additions and 42 deletions

View File

@ -458,7 +458,6 @@ def main():
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
@ -518,20 +517,18 @@ def main():
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step() # We have accumulated enought gradients
model.zero_grad()
@ -579,13 +576,13 @@ def main():
nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps #len(eval_dataloader)
eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': tr_loss/nb_tr_steps}#'loss': loss.item()}
'loss': tr_loss/nb_tr_steps}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:

View File

@ -743,7 +743,7 @@ def main():
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
@ -857,20 +857,13 @@ def main():
model.train()
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
start_positions = start_positions.to(device)
end_positions = start_positions.to(device)
start_positions = start_positions.view(-1, 1)
end_positions = end_positions.view(-1, 1)
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step() # We have accumulated enought gradients
@ -908,30 +901,22 @@ def main():
model.eval()
all_results = []
logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
for idx, i in enumerate(unique_id):
s = [float(x) for x in start_logits[idx]]
e = [float(x) for x in end_logits[idx]]
all_results.append(
RawResult(
unique_id=i,
start_logits=s,
end_logits=e
)
)
with torch.no_grad():
batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
for i, example_index in enumerate(example_indices):
start_logits = batch_start_logits[i].detach().cpu().tolist()
end_logits = batch_end_logits[i].detach().cpu().tolist()
eval_feature = eval_features[example_index.item()]
unique_id = int(eval_feature.unique_id)
all_results.append(RawResult(unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
write_predictions(eval_examples, eval_features, all_results,