Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Replace labels with -100 to skip loss calc (#4718)
This commit is contained in:
Parent: 6894b486d0 · Commit: 0a3d0e02c5
@@ -82,7 +82,9 @@ class DataCollatorForLanguageModeling:
             inputs, labels = self.mask_tokens(batch)
             return {"input_ids": inputs, "labels": labels}
         else:
-            return {"input_ids": batch, "labels": batch}
+            labels = batch.clone().detach()
+            labels[labels == self.tokenizer.pad_token_id] = -100
+            return {"input_ids": batch, "labels": labels}
 
     def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
         length_of_first = examples[0].size(0)
Loading…
Reference in New Issue
Block a user