mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
commit
5c0838d846
@ -467,6 +467,6 @@ class BertForQuestionAnswering(nn.Module):
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
return total_loss, (start_logits, end_logits)
|
||||
return total_loss
|
||||
else:
|
||||
return start_logits, end_logits
|
||||
|
@ -458,7 +458,6 @@ def main():
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name]()
|
||||
|
||||
label_list = processor.get_labels()
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
@ -515,23 +514,21 @@ def main():
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
tr_loss = 0
|
||||
nb_tr_examples, nb_tr_steps = 0, 0
|
||||
for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
nb_tr_steps += 1
|
||||
loss.backward()
|
||||
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
optimizer.step() # We have accumulated enought gradients
|
||||
model.zero_grad()
|
||||
@ -567,7 +564,8 @@ def main():
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
with torch.no_grad():
|
||||
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
|
||||
logits = logits.detach().cpu().numpy()
|
||||
label_ids = label_ids.to('cpu').numpy()
|
||||
@ -579,13 +577,13 @@ def main():
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
nb_eval_steps += 1
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps #len(eval_dataloader)
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||
|
||||
result = {'eval_loss': eval_loss,
|
||||
'eval_accuracy': eval_accuracy,
|
||||
'global_step': global_step,
|
||||
'loss': tr_loss/nb_tr_steps}#'loss': loss.item()}
|
||||
'loss': tr_loss/nb_tr_steps}
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
|
24
run_squad.py
24
run_squad.py
@ -743,7 +743,7 @@ def main():
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
@ -855,22 +855,15 @@ def main():
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
start_positions = start_positions.to(device)
|
||||
end_positions = start_positions.to(device)
|
||||
|
||||
start_positions = start_positions.view(-1, 1)
|
||||
end_positions = end_positions.view(-1, 1)
|
||||
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
optimizer.step() # We have accumulated enought gradients
|
||||
@ -911,24 +904,19 @@ def main():
|
||||
for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
if len(all_results) % 1000 == 0:
|
||||
logger.info("Processing example: %d" % (len(all_results)))
|
||||
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
start_logits = batch_start_logits[i].detach().cpu().tolist()
|
||||
end_logits = batch_end_logits[i].detach().cpu().tolist()
|
||||
|
||||
eval_feature = eval_features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
all_results.append(RawResult(unique_id=unique_id,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
||||
write_predictions(eval_examples, eval_features, all_results,
|
||||
|
Loading…
Reference in New Issue
Block a user