Mirror of https://github.com/huggingface/transformers.git
Don't modify labels inplace in LabelSmoother (#13464)
commit cd66539662
parent c164c651dc
@@ -458,7 +458,7 @@ class LabelSmoother:
         padding_mask = labels.eq(self.ignore_index)
         # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
         # will ignore them in any case.
-        labels.clamp_min_(0)
+        labels = torch.clamp(labels, min=0)
         nll_loss = log_probs.gather(dim=-1, index=labels)
         # works for fp16 input tensor too, by internally upcasting it to fp32
         smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
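The change swaps PyTorch's in-place Tensor.clamp_min_ for the out-of-place torch.clamp, so the caller's labels tensor is no longer mutated. A minimal standalone sketch of the difference (the tensor values are illustrative, not taken from the patch):

import torch

labels = torch.tensor([[7, 3, -100]])   # -100 is the default ignore_index

# Old behavior: clamp_min_ mutates the tensor in place, destroying the
# -100 markers for any later consumer of `labels`.
mutated = labels.clone()
mutated.clamp_min_(0)                    # mutated is now [[7, 3, 0]]

# New behavior: torch.clamp returns a new tensor; the original is untouched.
clamped = torch.clamp(labels, min=0)
assert labels[0, 2].item() == -100       # caller still sees the ignore_index
assert clamped[0, 2].item() == 0         # clamped copy is safe for gather()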
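For context, here is a self-contained sketch of the computation this hunk sits in, reconstructed from the context lines above; the log_softmax step and the final masked_fill_ calls are assumptions about the surrounding LabelSmoother code, not part of this diff:

import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(2, 4, 10)               # (batch, seq_len, vocab_size)
labels = torch.randint(0, 10, (2, 4))
labels[:, -1] = ignore_index                 # mark padding positions

log_probs = -F.log_softmax(logits, dim=-1)   # assumed preceding step
labels = labels.unsqueeze(-1)                # gather() needs matching dims
padding_mask = labels.eq(ignore_index)
# gather() cannot index with -100, so clamp a copy to 0; the padding mask
# zeroes those positions out again afterwards.
labels = torch.clamp(labels, min=0)
nll_loss = log_probs.gather(dim=-1, index=labels)
# dtype=torch.float32 upcasts fp16 inputs inside the reduction
smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
nll_loss.masked_fill_(padding_mask, 0.0)     # assumed follow-up masking
smoothed_loss.masked_fill_(padding_mask, 0.0)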