fix nan in full-fp16 label_smoothing eval (#10815)

This commit is contained in:
Stas Bekman 2021-03-22 19:23:24 -07:00 committed by GitHub
parent b5b957a65c
commit e21f89f64c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -433,7 +433,8 @@ class LabelSmoother:
# will ignore them in any case.
labels.clamp_min_(0)
nll_loss = log_probs.gather(dim=-1, index=labels)
smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
# works for fp16 input tensor too, by internally upcasting it to fp32
smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
nll_loss.masked_fill_(padding_mask, 0.0)
smoothed_loss.masked_fill_(padding_mask, 0.0)