Don't modify labels inplace in LabelSmoother (#13464)

This commit is contained in:
Sylvain Gugger 2021-09-08 07:45:36 -04:00 committed by GitHub
parent c164c651dc
commit cd66539662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -458,7 +458,7 @@ class LabelSmoother:
padding_mask = labels.eq(self.ignore_index)
# In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
# will ignore them in any case.
labels.clamp_min_(0)
labels = torch.clamp(labels, min=0)
nll_loss = log_probs.gather(dim=-1, index=labels)
# works for fp16 input tensor too, by internally upcasting it to fp32
smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)