[run_lm_finetuning] Tweak fix for non-long tensor, close #2728

see 1ebfeb7946 and #2728

Co-Authored-By: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Julien Chaumond 2020-02-05 12:49:18 -05:00
parent 2184f87003
commit ada24def22


@@ -118,7 +118,7 @@ class TextDataset(Dataset):
         return len(self.examples)
 
     def __getitem__(self, item):
-        return torch.tensor(self.examples[item])
+        return torch.tensor(self.examples[item], dtype=torch.long)
 
 
 class LineByLineTextDataset(Dataset):
@@ -138,7 +138,7 @@ class LineByLineTextDataset(Dataset):
         return len(self.examples)
 
     def __getitem__(self, i):
-        return torch.tensor(self.examples[i])
+        return torch.tensor(self.examples[i], dtype=torch.long)
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False):
@@ -195,7 +195,6 @@ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -
 
 def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
-    inputs = inputs.clone().type(dtype=torch.long)
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
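For context, a minimal sketch (not part of the commit) of why the explicit dtype=torch.long belongs in __getitem__: cached token ids that arrive as, say, numpy int32 would otherwise reach mask_tokens and the embedding layer as an IntTensor, which older PyTorch releases reject as indices. The ids array and the embedding sizes below are hypothetical, purely for illustration.

    import numpy as np
    import torch

    ids = np.array([101, 2023, 2003, 102], dtype=np.int32)  # hypothetical cached token ids

    implicit = torch.tensor(ids)                    # inherits int32 from numpy -> IntTensor
    explicit = torch.tensor(ids, dtype=torch.long)  # what the patched __getitem__ returns

    # Arbitrary sizes for the sketch; any nn.Embedding shows the same behavior.
    embedding = torch.nn.Embedding(num_embeddings=30522, embedding_dim=8)

    # embedding(implicit) fails on older PyTorch versions, which require long indices;
    # with the explicit dtype the lookup works regardless of how the ids were cached:
    print(explicit.dtype)             # torch.int64
    print(embedding(explicit).shape)  # torch.Size([4, 8])

Pushing the cast into the two Dataset classes also lets mask_tokens drop its own .type(dtype=torch.long) call, since every batch it receives is already long.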