add exemplary training data. update to nvidia apex. refactor 'item -> line in doc' mapping. add warning for unknown word.
parent 17595ef2de
commit e5fc98c542

@@ -498,8 +498,8 @@ loss = 0.06423990014260186

#### LM Fine-tuning

The data should be a text file in the same format as [sample_text.txt](./samples/sample_text.txt) (one sentence per line, docs separated by an empty line).
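For illustration, a corpus in this format might look like the short made-up snippet below (two documents, one sentence per line, separated by a blank line; the sentences are invented, not taken from the linked sample file):

```text
The quick brown fox jumps over the lazy dog.
It then wanders back into the forest.

A second document starts after the blank line.
Each of its sentences also sits on its own line.
```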

Training one epoch on a 500k sentence corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:

You can download an [exemplary training corpus](https://ext-bert-sample.obs.eu-de.otc.t-systems.com/small_wiki_sentence_corpus.txt) generated from Wikipedia articles and split into ~500k sentences with spaCy.

Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:

```shell

@@ -1,5 +1,6 @@

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -41,6 +42,12 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message
logger = logging.getLogger(__name__)


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x


class BERTDataset(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab = tokenizer.vocab

@@ -59,6 +66,7 @@ class BERTDataset(Dataset):
        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = [] # map sample index to doc and line

        # load samples into memory
        if on_memory:

@@ -71,12 +79,20 @@ class BERTDataset(Dataset):
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1
                    self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

@@ -159,20 +175,11 @@ class BERTDataset(Dataset):
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            # get the right doc
            doc_id = 0
            doc_start = 0
            doc_end = len(self.all_docs[doc_id]) - 2
            while item > doc_end:
                doc_id += 1
                doc_start = doc_end + 1
                doc_end += len(self.all_docs[doc_id]) - 1
            # get the right line within doc
            line_in_doc = item - doc_start
            t1 = self.all_docs[doc_id][line_in_doc]
            t2 = self.all_docs[doc_id][line_in_doc + 1]
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            t2 = self.all_docs[sample["doc_id"]][sample["line"]+1]
            # used later to avoid random nextSentence from same doc
            self.current_doc = doc_id
            self.current_doc = sample["doc_id"]
            return t1, t2
        else:
            if self.line_buffer is None:
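To make the refactored mapping concrete, here is a minimal standalone sketch of the "item -> line in doc" index; the `docs`, `sample_to_doc`, and `get_pair` names below are illustrative only and not part of the script:

# Illustrative sketch: precompute a flat index so sample lookup is O(1)
# instead of scanning cumulative document lengths for every item.
docs = [["a b c", "d e f", "g h"], ["i j", "k l m"]]  # toy documents

sample_to_doc = []
for doc_id, doc in enumerate(docs):
    for line_id in range(len(doc) - 1):  # last line has no subsequent line
        sample_to_doc.append({"doc_id": doc_id, "line": line_id})

def get_pair(item):
    sample = sample_to_doc[item]
    t1 = docs[sample["doc_id"]][sample["line"]]
    t2 = docs[sample["doc_id"]][sample["line"] + 1]
    return t1, t2

print(get_pair(0))  # ('a b c', 'd e f')
print(get_pair(2))  # ('i j', 'k l m')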

@@ -297,6 +304,7 @@ def random_word(tokens, tokenizer):
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
                logger.warning("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)
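The same [UNK] fallback can also be sketched with a membership test instead of try/except; the `token_to_id` helper and the toy `vocab` dict below are assumptions for illustration, not the script's API:

import logging

logger = logging.getLogger(__name__)

def token_to_id(token, vocab):
    # Fall back to [UNK] for out-of-vocabulary tokens and warn,
    # mirroring the except KeyError branch above.
    if token in vocab:
        return vocab[token]
    logger.warning("Cannot find token '%s' in vocab. Using [UNK] instead", token)
    return vocab["[UNK]"]

vocab = {"[UNK]": 100, "hello": 7592}
print(token_to_id("hello", vocab))    # 7592
print(token_to_id("suchwow", vocab))  # 100, plus a warning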

@@ -468,17 +476,15 @@ def main():
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--optimize_on_cpu',
                        default=False,
                        action='store_true',
                        help="Whether to perform optimization and keep the optimizer averages on CPU")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=128,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

@@ -486,14 +492,13 @@ def main():
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        if args.fp16:
            logger.info("16-bits training currently not supported in distributed training")
            args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(

@@ -531,29 +536,42 @@ def main():
        model.half()
    model.to(device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.fp16:
        param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
                            for n, param in model.named_parameters()]
    elif args.optimize_on_cpu:
        param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
                            for n, param in model.named_parameters()]
    else:
        param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

    global_step = 0
    if args.do_train:

@@ -580,33 +598,22 @@ def main():
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.fp16 and args.loss_scale != 1.0:
                    # rescale loss for fp16 training
                    # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
                    loss = loss * args.loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16 or args.optimize_on_cpu:
                        if args.fp16 and args.loss_scale != 1.0:
                            # scale down gradients for fp16 training
                            for param in model.parameters():
                                param.grad.data = param.grad.data / args.loss_scale
                        is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
                        if is_nan:
                            logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
                            args.loss_scale = args.loss_scale / 2
                            model.zero_grad()
                            continue
                        optimizer.step()
                        copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
                    else:
                        optimizer.step()
                    model.zero_grad()
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        logger.info("***** Saving fine-tuned model *****")

@@ -639,36 +646,5 @@ def accuracy(out, labels):
    return np.sum(outputs == labels)


def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
    """ Utility function for optimize_on_cpu and 16-bits training.
        Copy the parameters optimized on CPU/RAM back to the model on GPU
    """
    for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
        if name_opti != name_model:
            logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
            raise ValueError
        param_model.data.copy_(param_opti.data)


def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
    """ Utility function for optimize_on_cpu and 16-bits training.
        Copy the gradient of the GPU parameters to the CPU/RAM copy of the model
    """
    is_nan = False
    for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
        if name_opti != name_model:
            logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
            raise ValueError
        if param_model.grad is not None:
            if test_nan and torch.isnan(param_model.grad).sum() > 0:
                is_nan = True
            if param_opti.grad is None:
                param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
            param_opti.grad.data.copy_(param_model.grad.data)
        else:
            param_opti.grad = None
    return is_nan


if __name__ == "__main__":
    main()