From ada24def22199459d8c1decc311dfe8dae7a7d8c Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Wed, 5 Feb 2020 12:49:18 -0500
Subject: [PATCH] [run_lm_finetuning] Tweak fix for non-long tensor, close #2728

see 1ebfeb79469d544a2bd817aa32c77e0514485ff9 and #2728

Co-Authored-By: Lysandre Debut
---
 examples/run_lm_finetuning.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py
index 99f961e4218..00e9e2f1234 100644
--- a/examples/run_lm_finetuning.py
+++ b/examples/run_lm_finetuning.py
@@ -118,7 +118,7 @@ class TextDataset(Dataset):
         return len(self.examples)
 
     def __getitem__(self, item):
-        return torch.tensor(self.examples[item])
+        return torch.tensor(self.examples[item], dtype=torch.long)
 
 
 class LineByLineTextDataset(Dataset):
@@ -138,7 +138,7 @@ class LineByLineTextDataset(Dataset):
         return len(self.examples)
 
     def __getitem__(self, i):
-        return torch.tensor(self.examples[i])
+        return torch.tensor(self.examples[i], dtype=torch.long)
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False):
@@ -195,7 +195,6 @@ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -
 
 def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
-    inputs = inputs.clone().type(dtype=torch.long)
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
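
For context, a minimal sketch (not part of the patch) of why the dtype cast is moved into the datasets' __getitem__ rather than kept in mask_tokens: embedding layers only accept integer (long) indices, so returning torch.long tensors from the dataset keeps every code path on the right dtype without the extra clone-and-cast this patch removes. The example values, vocabulary size, and hidden size below are made-up placeholders.

# Hypothetical illustration, not from run_lm_finetuning.py: if cached token ids
# ever come back as Python floats, plain torch.tensor() infers float32, which
# nn.Embedding rejects; an explicit dtype=torch.long in __getitem__ avoids that.
import torch
import torch.nn as nn

examples = [[101.0, 2023.0, 102.0]]   # made-up cached token ids with a float dtype
embedding = nn.Embedding(30522, 8)    # vocab/hidden sizes are placeholders

inferred = torch.tensor(examples[0])                    # torch.float32
explicit = torch.tensor(examples[0], dtype=torch.long)  # torch.int64, as the patched __getitem__ returns

print(inferred.dtype, explicit.dtype)
print(embedding(explicit).shape)      # torch.Size([3, 8]); embedding(inferred) would raise a dtype error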