typo fix in output tuple

This commit is contained in:
thomwolf 2018-11-07 23:51:12 +01:00
parent d92a7f7721
commit 48d4a5317c

View File

@ -520,7 +520,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss = model(input_ids, segment_ids, input_mask, label_ids)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1: