Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
fix nan in full-fp16 label_smoothing eval (#10815)
parent b5b957a65c
commit e21f89f64c
@@ -433,7 +433,8 @@ class LabelSmoother:
         # will ignore them in any case.
         labels.clamp_min_(0)
         nll_loss = log_probs.gather(dim=-1, index=labels)
-        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+        # works for fp16 input tensor too, by internally upcasting it to fp32
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
 
         nll_loss.masked_fill_(padding_mask, 0.0)
         smoothed_loss.masked_fill_(padding_mask, 0.0)
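
Why the upcast is needed: during full-fp16 evaluation, log_probs is an fp16 tensor, and summing it across the vocabulary dimension yields a magnitude far beyond fp16's representable range (about ±65504), so the sum overflows to -inf and the downstream label-smoothing arithmetic produces NaN. Passing dtype=torch.float32 to sum() makes PyTorch accumulate and return the result in fp32. A minimal standalone sketch of the failure and the fix (the vocab size and tensors are illustrative, not from the commit):

import torch

vocab_size = 50_000  # illustrative; real LM vocabularies are of this order
logits = torch.randn(1, vocab_size)
log_probs = torch.log_softmax(logits, dim=-1).half()  # fp16, as in full-fp16 eval

# Each log-prob is roughly log(1/50000) ~ -10.8, so the row sum is ~ -540000,
# which overflows fp16's minimum of -65504 to -inf:
print(log_probs.sum(dim=-1, keepdim=True))  # tensor([[-inf]], dtype=torch.float16)

# The fix from this commit: have sum() accumulate in fp32 instead.
print(log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32))  # finite, ~ -5.4e5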