Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)

Commit 0f544625f4 (parent 0cf88ff084): "fix swag example for work with apex"

The hunks below fix the line continuations in the README's run_swag.py example, replace the script's CPU-optimizer/manual fp16 helpers with apex-based distributed and fp16 training (DistributedDataParallel, FusedAdam wrapped in FP16_Optimizer, a manual warmup_linear schedule), and rename BertAdam's `weight_decay_rate` argument to `weight_decay` in the optimization test.
@@ -442,12 +442,12 @@ python run_swag.py \
 --do_train \
 --do_lower_case \
 --do_eval \
---data_dir $SWAG_DIR/data
+--data_dir $SWAG_DIR/data \
 --train_batch_size 16 \
 --learning_rate 2e-5 \
 --num_train_epochs 3.0 \
 --max_seq_length 80 \
---output_dir /tmp/swag_output/
+--output_dir /tmp/swag_output/ \
 --gradient_accumulation_steps 4
 ```
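A note on the example flags: the run scripts of this era split `--train_batch_size` across `--gradient_accumulation_steps`, so each forward/backward pass sees a smaller micro-batch while the optimizer still updates on the full batch. A minimal sketch of that arithmetic, assuming the usual division the scripts performed (variable names here are illustrative):

```python
# Hedged sketch: how --train_batch_size 16 and --gradient_accumulation_steps 4 combine,
# assuming the script divides the requested batch size by the accumulation steps.
train_batch_size = 16
gradient_accumulation_steps = 4

micro_batch_size = train_batch_size // gradient_accumulation_steps     # 4 examples per forward/backward pass
effective_batch_size = micro_batch_size * gradient_accumulation_steps  # still 16 examples per optimizer update

print(micro_batch_size, effective_batch_size)  # 4 16
```

Gradient accumulation trades GPU memory for additional forward passes while keeping the effective batch size fixed.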
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -232,34 +233,10 @@ def select_field(features, field):
         for feature in features
     ]
 
-def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
-    """ Utility function for optimize_on_cpu and 16-bits training.
-        Copy the parameters optimized on CPU/RAM back to the model on GPU
-    """
-    for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
-        if name_opti != name_model:
-            logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
-            raise ValueError
-        param_model.data.copy_(param_opti.data)
-
-def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
-    """ Utility function for optimize_on_cpu and 16-bits training.
-        Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
-    """
-    is_nan = False
-    for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
-        if name_opti != name_model:
-            logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
-            raise ValueError
-        if param_model.grad is not None:
-            if test_nan and torch.isnan(param_model.grad).sum() > 0:
-                is_nan = True
-            if param_opti.grad is None:
-                param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
-            param_opti.grad.data.copy_(param_model.grad.data)
-        else:
-            param_opti.grad = None
-    return is_nan
+def warmup_linear(x, warmup=0.002):
+    if x < warmup:
+        return x/warmup
+    return 1.0 - x
 
 def main():
     parser = argparse.ArgumentParser()
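For reference, the `warmup_linear` helper introduced above returns a learning-rate multiplier: it ramps linearly from 0 to 1 over the warmup fraction of training and then decays linearly back to 0. A small usage sketch (the 0.1 warmup value is illustrative, matching what `--warmup_proportion 0.1` would pass in):

```python
def warmup_linear(x, warmup=0.002):
    # x is the training progress in [0, 1]; the return value scales the base learning rate.
    if x < warmup:
        return x / warmup   # linear ramp-up during the warmup fraction
    return 1.0 - x          # linear decay over the remainder of training

# Illustrative values with a 10% warmup:
for progress in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(progress, warmup_linear(progress, warmup=0.1))
# 0.0 -> 0.0, 0.05 -> 0.5, 0.1 -> 0.9, 0.5 -> 0.5, 1.0 -> 0.0
```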
@@ -335,17 +312,15 @@ def main():
                         type=int,
                         default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument('--optimize_on_cpu',
-                        default=False,
-                        action='store_true',
-                        help="Whether to perform optimization and keep the optimizer averages on CPU")
     parser.add_argument('--fp16',
                         default=False,
                         action='store_true',
                         help="Whether to use 16-bit float precision instead of 32-bit")
     parser.add_argument('--loss_scale',
-                        type=float, default=128,
-                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
+                        type=float, default=0,
+                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
+                             "0 (default value): dynamic loss scaling.\n"
+                             "Positive power of 2: static loss scaling value.\n")
 
     args = parser.parse_args()
 
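The `--loss_scale` semantics change here: `0` (the new default) requests dynamic loss scaling, while any positive power of 2 is used as a fixed scale. A condensed sketch of how the flag is consumed by the fp16 optimizer setup later in this commit, wrapped in a small helper purely for illustration:

```python
# Condensed sketch of how --loss_scale is consumed further down in this diff.
# FP16_Optimizer is apex's fp16 wrapper (https://www.github.com/nvidia/apex).
from apex.optimizers import FP16_Optimizer

def wrap_fp16(optimizer, loss_scale):
    if loss_scale == 0:
        # dynamic loss scaling: the scale adapts when gradient overflows occur
        return FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    # static loss scaling: a fixed, user-chosen power of 2
    return FP16_Optimizer(optimizer, static_loss_scale=loss_scale)
```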
@@ -353,13 +328,11 @@ def main():
         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
         n_gpu = torch.cuda.device_count()
     else:
+        torch.cuda.set_device(args.local_rank)
         device = torch.device("cuda", args.local_rank)
         n_gpu = 1
         # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
-        if args.fp16:
-            logger.info("16-bits training currently not supported in distributed training")
-            args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
     logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
 
     if args.gradient_accumulation_steps < 1:
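For context, the distributed branch now pins each worker to its own GPU with `torch.cuda.set_device` before initializing NCCL, and fp16 is no longer force-disabled in that path. A minimal sketch of the resulting device selection, refactored into a standalone helper purely for illustration:

```python
import torch
import torch.distributed as dist

def setup_device(local_rank: int, no_cuda: bool = False):
    # Illustrative restatement of the device-selection logic in the hunk above.
    if local_rank == -1 or no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(local_rank)           # bind this process to its GPU before NCCL init
        device = torch.device("cuda", local_rank)
        n_gpu = 1
        dist.init_process_group(backend="nccl")     # synchronizes the participating processes
    return device, n_gpu
```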
@@ -399,32 +372,50 @@ def main():
         model.half()
     model.to(device)
     if args.local_rank != -1:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
-                                                          output_device=args.local_rank)
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
     # Prepare optimizer
-    if args.fp16:
-        param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
-                            for n, param in model.named_parameters()]
-    elif args.optimize_on_cpu:
-        param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
-                            for n, param in model.named_parameters()]
-    else:
-        param_optimizer = list(model.named_parameters())
-    no_decay = ['bias', 'gamma', 'beta']
+    param_optimizer = list(model.named_parameters())
+
+    # hack to remove pooler, which is not used
+    # thus it produce None grad that break apex
+    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
+
+    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
-        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
+        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
+        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     t_total = num_train_steps
     if args.local_rank != -1:
         t_total = t_total // torch.distributed.get_world_size()
-    optimizer = BertAdam(optimizer_grouped_parameters,
-                         lr=args.learning_rate,
-                         warmup=args.warmup_proportion,
-                         t_total=t_total)
+    if args.fp16:
+        try:
+            from apex.optimizers import FP16_Optimizer
+            from apex.optimizers import FusedAdam
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        optimizer = FusedAdam(optimizer_grouped_parameters,
+                              lr=args.learning_rate,
+                              bias_correction=False,
+                              max_grad_norm=1.0)
+        if args.loss_scale == 0:
+            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+        else:
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+    else:
+        optimizer = BertAdam(optimizer_grouped_parameters,
+                             lr=args.learning_rate,
+                             warmup=args.warmup_proportion,
+                             t_total=t_total)
 
     global_step = 0
     if args.do_train:
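Taken together, the fp16 path above replaces BertAdam with apex's FusedAdam wrapped in FP16_Optimizer, which keeps fp32 master weights and owns the loss scaling; the warmup schedule is applied manually in the training loop instead. A hedged, roughly self-contained sketch of that construction (the tiny model and hyperparameter values are placeholders; apex and a GPU are assumed):

```python
import torch
from apex.optimizers import FP16_Optimizer, FusedAdam  # requires https://www.github.com/nvidia/apex

model = torch.nn.Linear(8, 2).cuda().half()   # stand-in for the half-precision BERT model
loss_scale = 0                                # 0 selects dynamic loss scaling, as with --loss_scale 0

optimizer = FusedAdam(model.parameters(),
                      lr=2e-5,
                      bias_correction=False,  # BertAdam applies no Adam bias correction, so neither does this
                      max_grad_norm=1.0)
if loss_scale == 0:
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=loss_scale)
```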
@@ -461,28 +452,18 @@ def main():
                     loss = loss * args.loss_scale
                 if args.gradient_accumulation_steps > 1:
                     loss = loss / args.gradient_accumulation_steps
-                loss.backward()
-                tr_loss += loss.item()
-                nb_tr_examples += input_ids.size(0)
-                nb_tr_steps += 1
+
+                if args.fp16:
+                    optimizer.backward(loss)
+                else:
+                    loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    if args.fp16 or args.optimize_on_cpu:
-                        if args.fp16 and args.loss_scale != 1.0:
-                            # scale down gradients for fp16 training
-                            for param in model.parameters():
-                                if param.grad is not None:
-                                    param.grad.data = param.grad.data / args.loss_scale
-                        is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
-                        if is_nan:
-                            logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
-                            args.loss_scale = args.loss_scale / 2
-                            model.zero_grad()
-                            continue
-                        optimizer.step()
-                        copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
-                    else:
-                        optimizer.step()
-                    model.zero_grad()
+                    # modify learning rate with special warm up BERT uses
+                    lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
+                    for param_group in optimizer.param_groups:
+                        param_group['lr'] = lr_this_step
+                    optimizer.step()
+                    optimizer.zero_grad()
                     global_step += 1
 
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
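The training loop now routes the backward pass through the optimizer when fp16 is enabled and recomputes the BERT warmup learning rate by hand each update, since FusedAdam is built without a schedule. A sketch of one accumulated update under those assumptions, with the surrounding epoch/batch loop omitted and the hunk's names reused:

```python
def training_step(loss, step, optimizer, global_step, t_total, args):
    # Sketch of the per-batch logic above; `optimizer` is apex's FP16_Optimizer when args.fp16 is set.
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps          # average over accumulated micro-batches
    if args.fp16:
        optimizer.backward(loss)                                # FP16_Optimizer applies loss scaling internally
    else:
        loss.backward()
    if (step + 1) % args.gradient_accumulation_steps == 0:
        # BERT-style warmup: set this update's learning rate explicitly
        lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_this_step
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
    return global_step
```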
@@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase):
         criterion = torch.nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
         optimizer = BertAdam(params=[w], lr=2e-1,
-                             weight_decay_rate=0.0,
+                             weight_decay=0.0,
                              max_grad_norm=-1)
         for _ in range(100):
             loss = criterion(w, target)
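The test update mirrors a rename of `BertAdam`'s keyword argument from `weight_decay_rate` to `weight_decay` (the same rename applied to the grouped-parameter dicts earlier in this commit). A minimal usage sketch; the import path assumes the `pytorch_pretrained_bert` package layout of this era:

```python
import torch
from pytorch_pretrained_bert.optimization import BertAdam  # assumed import path for this era of the repo

w = torch.nn.Parameter(torch.zeros(3))
optimizer = BertAdam(params=[w], lr=2e-1,
                     weight_decay=0.0,   # formerly weight_decay_rate
                     max_grad_norm=-1)
```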