add fp16 training

This commit is contained in:
thomwolf 2018-11-12 15:15:02 +01:00
parent 5dfd19060a
commit 66b0090877
4 changed files with 157 additions and 52 deletions

View File

@ -100,19 +100,20 @@ python -m pytest -sv tests/
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
To help with fine-tuning these models, we have included four techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: optimize on CPU, gradient-accumulation, multi-gpu and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
Here is how to use these techniques in our scripts:
- **Optimize on CPU**: The Adam optimizer comprise 2 moving average of all the weights of the model which means that if you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU to free more room on the GPU(s). As the most computational intensive operation is the backward pass, this usually doesn't increase the computation time by a lot. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below). Activate this option with `--optimize_on_cpu` on the `run_squad.py` script.
- **Gradient Accumulation**: Gradient accumulation can be used by supplying a integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradient will be accumulated over `gradient_accumulation_steps` steps.
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are splitted over the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument. To use Distributed training, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see the above blog post for more details):
- **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument.
- **Optimize on CPU**: The Adam optimizer comprise 2 moving average of all the weights of the model which means that if you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal when using a large model like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU to free more room on the GPU(s). As the most computational intensive operation is the backward pass, this usually doesn't increase the computation time by a lot. This is the only way to fine-tune `BERT-large` in a reasonable time on GPU(s) (see below). Activate this option with `--optimize_on_cpu` on the `run_squad.py` script.
- **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by a factor of 2, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. 16bits training natively incoporate the behavior of `--optimize_on_gpu` so it's not needed to have the two flags at the same time. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scaling` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests.
Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details):
```bash
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
```
Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your machine (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
## TPU support and pretraining scripts
@ -215,7 +216,7 @@ To get these results we used a combination of:
- 2 steps of gradient accumulation and
- perform the optimization step on CPU to store Adam's averages in RAM.
Here are the full list of hyper-parameters for this run:
Here is the full list of hyper-parameters for this run:
```bash
python ./run_squad.py \
--vocab_file $BERT_LARGE_DIR/vocab.txt \

View File

@ -348,7 +348,7 @@ class BertModel(nn.Module):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.float()
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)

View File

@ -309,6 +309,32 @@ def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs==labels)
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the parameters optimized on CPU/RAM back to the model on GPU
"""
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
param_model.data.copy_(param_opti.data)
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
"""
is_nan = False
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
return is_nan
def main():
parser = argparse.ArgumentParser()
@ -404,6 +430,18 @@ def main():
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
parser.add_argument('--optimize_on_cpu',
default=False,
action='store_true',
help="Whether to perform optimization and keep the optimizer averages on CPU")
parser.add_argument('--fp16',
default=False,
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=128,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
args = parser.parse_args()
processors = {
@ -420,6 +458,9 @@ def main():
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
if args.fp16:
logger.info("16-bits training currently not supported in distributed training")
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
if args.gradient_accumulation_steps < 1:
@ -466,24 +507,34 @@ def main():
num_train_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
if args.fp16:
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
for n, param in model.named_parameters()]
elif args.optimize_on_cpu:
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
for n, param in model.named_parameters()]
else:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_parameters,
optimizer = BERTAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
@ -496,12 +547,10 @@ def main():
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
@ -519,6 +568,10 @@ def main():
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0:
# rescale loss for fp16 training
# see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
@ -526,7 +579,21 @@ def main():
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step() # We have accumulated enought gradients
if args.fp16 or args.optimize_on_cpu:
if args.fp16 and args.loss_scale != 1.0:
# scale down gradients for fp16 training
for param in model.parameters():
param.grad.data = param.grad.data / args.loss_scale
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
if is_nan:
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
args.loss_scale = args.loss_scale / 2
model.zero_grad()
continue
optimizer.step()
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
else:
optimizer.step()
model.zero_grad()
global_step += 1
@ -534,16 +601,13 @@ def main():
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)

View File

@ -669,6 +669,31 @@ def _compute_softmax(scores):
probs.append(score / total_sum)
return probs
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the parameters optimized on CPU/RAM back to the model on GPU
"""
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
param_model.data.copy_(param_opti.data)
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
""" Utility function for optimize_on_cpu and 16-bits training.
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
"""
is_nan = False
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
return is_nan
def main():
parser = argparse.ArgumentParser()
@ -746,7 +771,9 @@ def main():
default=False,
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=128,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
args = parser.parse_args()
@ -758,7 +785,11 @@ def main():
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
if args.fp16:
logger.info("16-bits training currently not supported in distributed training")
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits trainiing: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
@ -807,24 +838,12 @@ def main():
num_train_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForQuestionAnswering(bert_config)
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
if args.fp16:
model.half()
if not args.optimize_on_cpu:
model.to(device)
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
@ -832,6 +851,25 @@ def main():
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
if args.fp16:
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
for n, param in model.named_parameters()]
elif args.optimize_on_cpu:
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
for n, param in model.named_parameters()]
else:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
@ -846,19 +884,11 @@ def main():
logger.info(" Num split examples = %d", len(train_features))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
if args.fp16:
(all_input_ids, all_input_mask,
all_segment_ids, all_start_positions,
all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions))
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions)
if args.local_rank == -1:
@ -870,21 +900,36 @@ def main():
model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0:
# rescale loss for fp16 training
# see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.optimize_on_cpu:
model.to('cpu')
optimizer.step() # We have accumulated enought gradients
if args.fp16 or args.optimize_on_cpu:
if args.fp16 and args.loss_scale != 1.0:
# scale down gradients for fp16 training
for param in model.parameters():
param.grad.data = param.grad.data / args.loss_scale
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
if is_nan:
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
args.loss_scale = args.loss_scale / 2
model.zero_grad()
continue
optimizer.step()
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
else:
optimizer.step()
model.zero_grad()
if args.optimize_on_cpu:
model.to(device)
global_step += 1
if args.do_predict:
@ -907,11 +952,6 @@ def main():
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
if args.fp16:
(all_input_ids, all_input_mask,
all_segment_ids, all_example_index) = tuple(t.half() for t in (all_input_ids, all_input_mask,
all_segment_ids, all_example_index))
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)