Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
update apex fp16 implementation
parent caf1d116a6
commit c8731b9583
@@ -1,22 +1,20 @@
-from argparse import ArgumentParser
-from pathlib import Path
-import os
-import torch
-import logging
 import json
+import logging
 import random
-import numpy as np
+from argparse import ArgumentParser
 from collections import namedtuple
+from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import numpy as np
+import torch
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
-
 from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_transformers.modeling_bert import BertForPreTraining
-from pytorch_transformers.tokenization_bert import BertTokenizer
 from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
+from pytorch_transformers.tokenization_bert import BertTokenizer
 
 InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
 
@@ -72,16 +70,16 @@ class PregeneratedDataset(Dataset):
         if reduce_memory:
             self.temp_dir = TemporaryDirectory()
             self.working_dir = Path(self.temp_dir.name)
-            input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
+            input_ids = np.memmap(filename=self.working_dir / 'input_ids.memmap',
                                   mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
-            input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
+            input_masks = np.memmap(filename=self.working_dir / 'input_masks.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
+            segment_ids = np.memmap(filename=self.working_dir / 'segment_ids.memmap',
                                     shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
-            lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
+            lm_label_ids = np.memmap(filename=self.working_dir / 'lm_label_ids.memmap',
                                      shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
             lm_label_ids[:] = -1
-            is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
+            is_nexts = np.memmap(filename=self.working_dir / 'is_nexts.memmap',
                                  shape=(num_samples,), mode='w+', dtype=np.bool)
         else:
             input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
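For context on the reduce_memory branch above: np.memmap backs each feature array with a file on disk, so the pregenerated features never need to fit in RAM while still being indexable like ordinary arrays. A minimal standalone sketch of the same pattern follows; the sizes and file names are illustrative, not taken from the script.

import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory

# Illustrative sizes only
num_samples, seq_len = 1000, 128

temp_dir = TemporaryDirectory()
working_dir = Path(temp_dir.name)

# mode='w+' creates the backing file and maps it read/write; data lives on disk, not in RAM
input_ids = np.memmap(filename=working_dir / 'input_ids.memmap',
                      mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
lm_label_ids = np.memmap(filename=working_dir / 'lm_label_ids.memmap',
                         mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
lm_label_ids[:] = -1  # -1 marks positions that are not masked-LM targets

# Rows can be written and read like a normal ndarray
input_ids[0] = np.arange(seq_len, dtype=np.int32)
print(input_ids[0][:5])

temp_dir.cleanup()  # removes the backing files when done
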
@@ -125,7 +123,8 @@ def main():
     parser = ArgumentParser()
     parser.add_argument('--pregenerated_data', type=Path, required=True)
     parser.add_argument('--output_dir', type=Path, required=True)
-    parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
+    parser.add_argument("--bert_model", type=str, required=True,
+                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                              "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
     parser.add_argument("--do_lower_case", action="store_true")
     parser.add_argument("--reduce_memory", action="store_true",
@@ -153,14 +152,14 @@ def main():
     parser.add_argument('--loss_scale',
                         type=float, default=0,
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-                             "0 (default value): dynamic loss scaling.\n"
-                             "Positive power of 2: static loss scaling value.\n")
-    parser.add_argument("--warmup_steps", 
-                        default=0, 
+                        "0 (default value): dynamic loss scaling.\n"
+                        "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument("--warmup_steps",
+                        default=0,
                         type=int,
                         help="Linear warmup over warmup_steps.")
-    parser.add_argument("--adam_epsilon", 
-                        default=1e-8, 
+    parser.add_argument("--adam_epsilon",
+                        default=1e-8,
                         type=float,
                         help="Epsilon for Adam optimizer.")
     parser.add_argument("--learning_rate",
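The --loss_scale help text above distinguishes dynamic loss scaling (0) from a fixed scale. The new code path passes only opt_level to amp.initialize and therefore uses amp's default scaling, but amp.initialize also accepts a loss_scale argument that can be "dynamic" or a number, which lines up with this flag. A possible mapping, shown as an assumption rather than something this commit does:

from argparse import Namespace

import torch
from apex import amp  # assumes apex is installed and a CUDA device is available

args = Namespace(loss_scale=0)  # stand-in for the parsed command-line arguments
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# loss_scale="dynamic" mirrors --loss_scale 0; any positive value becomes a static scale
loss_scale = "dynamic" if args.loss_scale == 0 else args.loss_scale
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=loss_scale)
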
@@ -207,7 +206,7 @@ def main():
 
     if args.gradient_accumulation_steps < 1:
         raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
-            args.gradient_accumulation_steps))
+                             args.gradient_accumulation_steps))
 
     args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
 
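The division above is what makes gradient accumulation work: each forward/backward pass sees a smaller batch, while the effective batch per optimizer update stays at --train_batch_size. A tiny worked example with made-up numbers:

# Illustrative numbers, not from the script: with --train_batch_size 32 and
# --gradient_accumulation_steps 4, each pass sees 32 // 4 = 8 examples and the
# optimizer steps once every 4 passes, so the effective update batch stays 32.
train_batch_size = 32
gradient_accumulation_steps = 4
per_pass_batch = train_batch_size // gradient_accumulation_steps
assert per_pass_batch * gradient_accumulation_steps == train_batch_size
print(per_pass_batch)  # 8
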
@@ -235,8 +234,9 @@ def main():
 
     # Prepare model
     model = BertForPreTraining.from_pretrained(args.bert_model)
-    if args.fp16:
-        model.half()
+    # We no longer need to call model.half() manually, following Apex's recommendation
+    # if args.fp16:
+    #     model.half()
     model.to(device)
     if args.local_rank != -1:
         try:
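The deleted model.half() call is covered by amp: the model stays in FP32 here and amp.initialize decides later which operations run in FP16. One ordering detail worth noting is that the Apex documentation expects the model to already be on the GPU before amp.initialize is called, which is why model.to(device) stays where it is. A minimal sketch of that ordering, with a toy module standing in for BertForPreTraining and assuming apex plus a CUDA device:

import torch
from apex import amp  # assumes apex is installed

device = torch.device("cuda")                  # amp expects the model on the GPU first
model = torch.nn.Linear(768, 768).to(device)   # placeholder for BertForPreTraining
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# No manual model.half(): amp patches the model/optimizer pair instead
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
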
@@ -257,25 +257,36 @@ def main():
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
     ]
 
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps,
+                                     t_total=num_train_optimization_steps)
+
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer
-            from apex.optimizers import FusedAdam
+            # from apex.optimizers import FP16_Optimizer
+            # from apex.optimizers import FusedAdam
             from apex import amp
         except ImportError:
             raise ImportError(
                 "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
-        optimizer = FusedAdam(optimizer_grouped_parameters,
-                              lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
-        if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
-        else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
-    else:
-        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
+        # The line below is the main upgrade of the Apex FP16 implementation. I chose opt_level="O1"
+        # because it's recommended for typical use by Apex. We can make it configurable.
+        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+
+        # We don't need to wrap FusedAdam in FP16_Optimizer anymore either; Apex now supports any PyTorch optimizer.
+
+        # optimizer = FusedAdam(optimizer_grouped_parameters,
+        #                       lr=args.learning_rate,
+        #                       bias_correction=False,
+        #                       max_grad_norm=1.0)
+        # if args.loss_scale == 0:
+        #     optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+        # else:
+        #     optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+        # else:
+        #     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
 
     global_step = 0
     logging.info("***** Running training *****")
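Taken together, the new code follows the standard amp recipe: build the model and any PyTorch optimizer as usual, call amp.initialize once, then run the backward pass through amp.scale_loss. A self-contained sketch of that recipe with a toy model, assuming apex and a CUDA device; the opt_level summaries in the comments are from the Apex documentation:

import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# opt_level choices documented by Apex:
#   "O0" = pure FP32, "O1" = per-op casting of whitelisted ops (the commit's choice),
#   "O2" = "almost FP16" with FP32 master weights and FP32 batch norm, "O3" = pure FP16.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(4, 10).cuda()
loss = model(x).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()
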
@@ -300,11 +311,14 @@ def main():
                 outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                 loss = outputs[0]
                 if n_gpu > 1:
-                    loss = loss.mean() # mean() to average on multi-gpu.
+                    loss = loss.mean()  # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
                     loss = loss / args.gradient_accumulation_steps
                 if args.fp16:
-                    optimizer.backward(loss)
+                    # I deprecate FP16_Optimizer's backward func here and replace it as the Apex docs describe
+                    # optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                 else:
                     loss.backward()
                 tr_loss += loss.item()
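One behavioral difference worth flagging: the old FusedAdam call carried max_grad_norm=1.0, and the new path drops gradient clipping entirely. This commit leaves it out, but the Apex documentation describes how to reproduce it under amp by clipping the master parameters between the scaled backward pass and optimizer.step(). A sketch of that variant, with a toy model standing in for BERT and assuming apex plus a CUDA device:

import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# Apex-documented clipping when amp owns the (possibly FP32 master) parameters
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
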
@@ -322,7 +336,8 @@ def main():
     # Save a trained model
     if args.local_rank == -1 or torch.distributed.get_rank() == 0:
         logging.info("** ** * Saving fine-tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save = model.module if hasattr(model,
+                                                'module') else model  # Take care of distributed/parallel training
         model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
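The save path above uses save_pretrained for both the model and the tokenizer, so the fine-tuned weights can be reloaded with the matching from_pretrained calls. A short sketch, with an illustrative output directory name:

from pytorch_transformers.modeling_bert import BertForPreTraining
from pytorch_transformers.tokenization_bert import BertTokenizer

output_dir = "finetuned_lm/"  # hypothetical value of --output_dir
model = BertForPreTraining.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)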