Fixed the bug raised by "eval_loss += tmp_eval_loss.item()" when evaluating in parallel on multiple GPUs (tmp_eval_loss is a per-GPU vector there, so .item() fails; it must be reduced with .mean() first).

This commit is contained in:
focox@qq.com 2019-10-23 20:27:13 +08:00
parent ef1b8b2ae5
commit bd847ce7d7

View File

@ -210,6 +210,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
if args.n_gpu > 1:
tmp_eval_loss = tmp_eval_loss.mean() # mean() to average on multi-gpu parallel evaluating
eval_loss += tmp_eval_loss.item()
nb_eval_steps += 1
if preds is None: