Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
fixing tokenization and training
This commit is contained in:
parent 47d6853439
commit ab7bd5ef98
@@ -30,7 +30,7 @@ import random
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
@@ -72,14 +72,9 @@ class TextDataset(Dataset):
 
             tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
 
-            tokenized_text = tokenizer.add_special_tokens_single_sentence(tokenized_text)
             while len(tokenized_text) >= block_size: # Truncate in block of block_size
-                if isinstance(tokenizer, (BertTokenizer, RobertaTokenizer)):
-                    self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size - 2]))
-                    tokenized_text = tokenized_text[block_size - 2:]
-                else:
-                    self.examples.append(tokenized_text[:block_size])
-                    tokenized_text = tokenized_text[block_size:]
+                self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
+                tokenized_text = tokenized_text[block_size:]
             # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
             # If your dataset is small, first you should loook for a bigger one :-) and second you
             # can change this behavior by adding (model specific) padding.
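Note on this hunk: the rewritten loop adds the special tokens once per block and slices in plain block_size steps, instead of wrapping the whole corpus once and slicing with block_size - 2 only for BERT/RoBERTa. A minimal sketch of the new chunking logic, not part of the commit; cls_id and sep_id are made-up stand-ins for what a BERT-style tokenizer's add_special_tokens_single_sentence would prepend and append:

def chunk_with_special_tokens(token_ids, block_size, cls_id=101, sep_id=102):
    """Mimic the post-commit loop in TextDataset.__init__ (illustrative only)."""
    examples = []
    while len(token_ids) >= block_size:               # truncate in blocks of block_size
        examples.append([cls_id] + token_ids[:block_size] + [sep_id])
        token_ids = token_ids[block_size:]
    # the trailing partial block is dropped, as in the script (no padding)
    return examples

if __name__ == "__main__":
    ids = list(range(1000, 1010))                     # 10 fake token ids
    for example in chunk_with_special_tokens(ids, block_size=4):
        print(len(example), example)                  # each example is block_size + 2 ids long

With block_size later derived from tokenizer.max_len_single_sentence (last hunk below), each example lands exactly at the model's maximum input length.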
@@ -112,15 +107,15 @@ def mask_tokens(inputs, tokenizer, args):
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
-    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
+    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
     labels[~masked_indices] = -1 # We only compute loss on masked tokens
 
     # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
+    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
     inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
 
     # 10% of the time, we replace masked input tokens with random word
-    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
+    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
     random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
     inputs[indices_random] = random_words[indices_random]
 
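Note on this hunk: the three sampling masks switch from .byte() (uint8) to .bool(). From PyTorch 1.2 onward, boolean tensors are the intended type for mask indexing and for operators like ~ and &, while uint8 masks trigger deprecation warnings. A self-contained sketch of the same 80/10/10 masking recipe with boolean masks; the batch shape, vocabulary size, and mask-token id below are illustrative, not taken from the script:

import torch

vocab_size, mask_token_id, mlm_probability = 30000, 103, 0.15      # made-up values

inputs = torch.randint(0, vocab_size, (2, 8))                       # fake batch of token ids
labels = inputs.clone()

masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
labels[~masked_indices] = -1                                        # loss only on masked positions

# 80% of the masked positions become the mask token
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = mask_token_id

# 10% of the masked positions (half of the remaining 20%) become a random token
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

print(inputs)
print(labels)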
@@ -134,7 +129,7 @@ def train(args, train_dataset, model, tokenizer):
         tb_writer = SummaryWriter()
 
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
-    train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
+    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
 
     if args.max_steps > 0:
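Note on this hunk: the non-distributed training path now shuffles each epoch by using RandomSampler instead of SequentialSampler (matching the import change in the first hunk). A small sketch, independent of the script, contrasting the two samplers on a toy dataset:

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

for name, sampler in [("sequential", SequentialSampler(dataset)),
                      ("random", RandomSampler(dataset))]:
    loader = DataLoader(dataset, sampler=sampler, batch_size=4)
    order = [int(x) for (batch,) in loader for x in batch]
    print(name, order)   # "sequential" is always 0..9; "random" is a fresh permutation each run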
@@ -329,7 +324,7 @@ def main():
     parser.add_argument("--block_size", default=-1, type=int,
                         help="Optional input sequence length after tokenization."
                              "The training dataset will be truncated in block of this size for training."
-                             "Default to the model max input length.")
+                             "Default to the model max input length fo single sentences inputs (take into account special tokens).")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
@@ -433,7 +428,8 @@ def main():
     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
     if args.block_size <= 0:
-        args.block_size = tokenizer.max_len # Our input block size will be the max possible for the model
+        args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
+    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
     model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
     model.to(args.device)
 
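Note on this hunk: block_size now falls back to tokenizer.max_len_single_sentence rather than tokenizer.max_len, and a user-supplied value is additionally capped by it, so that each block still fits the model once the special tokens are added per block. A back-of-the-envelope sketch of the arithmetic, assuming a BERT-style tokenizer where two special tokens wrap a single sequence (the numbers are assumed, not read from an actual tokenizer; the attribute names follow the pytorch-transformers API of this period):

model_max_len = 512                       # tokenizer.max_len for a BERT-style model (assumed)
num_special_tokens = 2                    # [CLS] ... [SEP]
max_len_single_sentence = model_max_len - num_special_tokens      # 510

block_size = max_len_single_sentence      # the fallback the fixed script now uses
example_len = block_size + num_special_tokens                     # after adding special tokens per block

assert example_len == model_max_len       # blocks now fit the model exactly
print(block_size, example_len)            # 510 512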
|