Replace labels with -100 to skip loss calc (#4718)

This commit is contained in:
Setu Shah 2020-06-24 09:14:50 -07:00 committed by GitHub
parent 6894b486d0
commit 0a3d0e02c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -82,7 +82,9 @@ class DataCollatorForLanguageModeling:
inputs, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "labels": labels}
else:
return {"input_ids": batch, "labels": batch}
labels = batch.clone().detach()
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0)