From e21f89f64c0683f111572b8b9fa38ffff64885f1 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 22 Mar 2021 19:23:24 -0700
Subject: [PATCH] fix nan in full-fp16 label_smoothing eval (#10815)

---
 src/transformers/trainer_pt_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 31110f3b6c8..c20377f7091 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -433,7 +433,8 @@ class LabelSmoother:
         # will ignore them in any case.
         labels.clamp_min_(0)
         nll_loss = log_probs.gather(dim=-1, index=labels)
-        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+        # works for fp16 input tensor too, by internally upcasting it to fp32
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
 
         nll_loss.masked_fill_(padding_mask, 0.0)
         smoothed_loss.masked_fill_(padding_mask, 0.0)
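
A minimal sketch (not part of the patch) of why the dtype argument fixes the NaN: summing vocab-size-many fp16 log-probs overflows fp16's range (max magnitude ~65504) to -inf, which then propagates as NaN through the label-smoothing loss, while dtype=torch.float32 tells torch.sum to accumulate and return in fp32. The shapes and values below are illustrative assumptions, not taken from the patch:

import torch

# Hypothetical log-probs: 2 sequences, 4 tokens, 100k-entry vocab,
# every entry around -11.5 (roughly a uniform distribution, since
# -log(100_000) ~ -11.5).
log_probs = torch.full((2, 4, 100_000), -11.5, dtype=torch.float16)

# Pre-patch behavior: the true sum (~ -1.15e6) is outside fp16 range,
# so the fp16 result overflows to -inf.
bad = log_probs.sum(dim=-1, keepdim=True)
print(bad.isinf().any())   # tensor(True)

# The patched call: upcast during the reduction, result stays finite.
good = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
print(good.isinf().any())  # tensor(False)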