option to perform optimization and keep the optimizer averages on CPU

thomwolf 2018-11-09 11:28:14 +01:00
parent 9e95cd8cd6
commit 5f04aa00ed


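What the new code path does, in isolation: when --optimize_on_cpu is set, the parameters (and their gradients) are moved to the CPU before each optimizer step, so the optimizer's moving averages get allocated and stay in host memory while the forward/backward passes still run on the GPU. Below is a minimal sketch of that pattern, not the commit's code: a placeholder nn.Linear and torch.optim.Adam stand in for BertForQuestionAnswering and BERTAdam.

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

for _ in range(2):                            # two dummy training steps
    model.to(device)                          # forward/backward on the GPU
    x = torch.randn(4, 10, device=device)
    y = torch.randint(0, 2, (4,), device=device)
    loss = F.cross_entropy(model(x), y)
    loss.backward()

    model.to("cpu")                           # params and grads back to CPU
    optimizer.step()                          # Adam state (exp_avg, exp_avg_sq) is
                                              # created lazily on the params' device,
                                              # so the averages live in CPU memory
    model.zero_grad()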
@@ -719,7 +719,6 @@ def main():
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.")
parser.add_argument("--verbose_logging", default=False, action='store_true',
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
@@ -727,10 +726,6 @@ def main():
default=False,
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
@@ -738,7 +733,16 @@ def main():
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--optimize_on_cpu',
default=False,
action='store_true',
help="Whether to perform optimization and keep the optimizer averages on CPU")
args = parser.parse_args()
@@ -802,25 +806,26 @@ def main():
model = BertForQuestionAnswering(bert_config)
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
if not args.optimize_on_cpu:
model.to(device)
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
@@ -862,8 +867,12 @@ def main():
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.optimize_on_cpu:
model.to('cpu')
optimizer.step() # We have accumulated enough gradients
model.zero_grad()
if args.optimize_on_cpu:
model.to(device)
global_step += 1
if args.do_predict:
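Why moving the model around the step works: PyTorch's Adam-style optimizers create their per-parameter state lazily, on whatever device the parameter is on at the first step(). A quick self-contained check of this behaviour (again with torch.optim.Adam standing in for BERTAdam; this snippet is not part of the commit):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                        # parameters live on the CPU here
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

model(torch.randn(4, 10)).sum().backward()
optimizer.step()                                # first step -> state is allocated now

for state in optimizer.state.values():
    print(state["exp_avg"].device, state["exp_avg_sq"].device)   # prints: cpu cpu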