diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 31110f3b6c8..c20377f7091 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -433,7 +433,8 @@ class LabelSmoother:
         # will ignore them in any case.
         labels.clamp_min_(0)
         nll_loss = log_probs.gather(dim=-1, index=labels)
-        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+        # works for fp16 input tensor too, by internally upcasting it to fp32
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
         nll_loss.masked_fill_(padding_mask, 0.0)
         smoothed_loss.masked_fill_(padding_mask, 0.0)
 
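For context, a minimal standalone sketch (not part of the patch; the shapes and values are hypothetical) of why the explicit `dtype=torch.float32` matters: summing log-probabilities over a large vocabulary in fp16 can produce a value outside the fp16 range (roughly ±65504), which overflows to `-inf`, while passing `dtype` to `Tensor.sum` makes the accumulation and result fp32.

```python
import torch

# Hypothetical shapes: batch of 2, sequence length 4, GPT-2-sized vocab.
log_probs = torch.log_softmax(torch.randn(2, 4, 50257), dim=-1).half()

# Each log-prob is around -log(50257) ~ -10.8, so the per-token sum over the
# vocab dimension is roughly -5.4e5, which cannot be represented in fp16 and
# overflows to -inf.
fp16_sum = log_probs.sum(dim=-1, keepdim=True)
print(fp16_sum.dtype, torch.isinf(fp16_sum).any())  # torch.float16, True

# The patched call: the sum is accumulated and returned in fp32, so the
# smoothed loss stays finite even for fp16 inputs.
fp32_sum = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
print(fp32_sum.dtype, torch.isinf(fp32_sum).any())  # torch.float32, False
```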