fixing tokenization and training

thomwolf 2019-08-23 17:31:21 +02:00
parent 47d6853439
commit ab7bd5ef98


@@ -30,7 +30,7 @@ import random
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
@@ -72,13 +72,8 @@ class TextDataset(Dataset):
             tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
-            tokenized_text = tokenizer.add_special_tokens_single_sentence(tokenized_text)
             while len(tokenized_text) >= block_size:  # Truncate in block of block_size
-                if isinstance(tokenizer, (BertTokenizer, RobertaTokenizer)):
-                    self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size - 2]))
-                    tokenized_text = tokenized_text[block_size - 2:]
-                else:
-                    self.examples.append(tokenized_text[:block_size])
+                self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
                 tokenized_text = tokenized_text[block_size:]
             # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
             # If your dataset is small, first you should loook for a bigger one :-) and second you
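
With this change every block gets its own special tokens, whatever the tokenizer class. Below is a minimal, self-contained sketch of the new chunking behaviour; CLS_ID, SEP_ID and the two helpers are toy stand-ins (not the script's code) so it runs without pytorch-transformers:

CLS_ID, SEP_ID = 101, 102  # assumed BERT-style special-token ids

def add_special_tokens_single_sentence(token_ids):
    # stand-in for tokenizer.add_special_tokens_single_sentence()
    return [CLS_ID] + token_ids + [SEP_ID]

def build_examples(tokenized_text, block_size):
    examples = []
    # truncate in blocks of block_size; each block gets its own special tokens
    while len(tokenized_text) >= block_size:
        examples.append(add_special_tokens_single_sentence(tokenized_text[:block_size]))
        tokenized_text = tokenized_text[block_size:]
    # the trailing partial block is dropped (no padding), as the comment above notes
    return examples

print(build_examples(list(range(10)), 4))
# [[101, 0, 1, 2, 3, 102], [101, 4, 5, 6, 7, 102]]

Each example therefore ends up block_size plus the special tokens long, which is why the last hunk below clamps args.block_size to tokenizer.max_len_single_sentence.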
@@ -112,15 +107,15 @@ def mask_tokens(inputs, tokenizer, args):
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
-    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
+    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
     labels[~masked_indices] = -1  # We only compute loss on masked tokens

     # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
+    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
     inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

     # 10% of the time, we replace masked input tokens with random word
-    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
+    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
     random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
     inputs[indices_random] = random_words[indices_random]
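
The .byte() to .bool() switch follows PyTorch's introduction of the torch.bool dtype; masking and indexing with uint8 tensors was deprecated around PyTorch 1.2. A self-contained sketch of the same 80/10/10 masking with bool masks; vocab_size, mask_token_id and the default mlm_probability are made-up stand-ins for what the script reads from the tokenizer and args:

import torch

def mask_tokens_sketch(inputs, vocab_size, mask_token_id, mlm_probability=0.15):
    labels = inputs.clone()
    # sample the positions to mask
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -1  # loss is only computed on masked positions

    # 80% of masked positions: replace with the mask token
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10%: replace with a random token (half of the remaining 20%)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # remaining 10%: keep the original token
    return inputs, labels

inputs = torch.randint(5, 1000, (2, 16))
masked_inputs, labels = mask_tokens_sketch(inputs.clone(), vocab_size=1000, mask_token_id=103)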
@@ -134,7 +129,7 @@ def train(args, train_dataset, model, tokenizer):
         tb_writer = SummaryWriter()

     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
-    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
+    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

     if args.max_steps > 0:
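
Switching to RandomSampler means training batches are shuffled in the single-process case, while distributed runs keep DistributedSampler. A small sketch of the same selection logic; TensorDataset and the hard-coded local_rank are only for the demo, following the script's convention that local_rank == -1 means "not distributed":

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler

train_dataset = TensorDataset(torch.arange(100))
local_rank = -1  # -1: single-process training

train_sampler = RandomSampler(train_dataset) if local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=8)

print(next(iter(train_dataloader))[0])  # a shuffled batch; SequentialSampler would always yield 0..7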
@@ -329,7 +324,7 @@ def main():
     parser.add_argument("--block_size", default=-1, type=int,
                         help="Optional input sequence length after tokenization."
                              "The training dataset will be truncated in block of this size for training."
-                             "Default to the model max input length.")
+                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action='store_true',
@ -433,7 +428,8 @@ def main():
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
if args.block_size <= 0:
args.block_size = tokenizer.max_len # Our input block size will be the max possible for the model
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
model.to(args.device)
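
max_len_single_sentence is the model's maximum input length minus the special tokens added around a single sequence (510 rather than 512 for BERT-style models), so the blocks built in TextDataset still fit once [CLS]/[SEP] are added. A sketch of the resolved block size, with made-up attribute values standing in for a real tokenizer:

class FakeTokenizer:
    # stand-in numbers; a BERT tokenizer reports max_len == 512 and
    # max_len_single_sentence == 510 (two special tokens per single sequence)
    max_len = 512
    max_len_single_sentence = 510

tokenizer = FakeTokenizer()
block_size = -1  # the argparse default

if block_size <= 0:
    block_size = tokenizer.max_len_single_sentence  # default: longest single-sentence input
block_size = min(block_size, tokenizer.max_len_single_sentence)  # never exceed it

print(block_size)  # 510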