mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #777 from huggingface/examples
Working GLUE Example for XLNet (STS-B)
This commit is contained in:
commit
d216e798af
14
README.md
14
README.md
@ -1620,20 +1620,10 @@ and unpack it to some directory `$GLUE_DIR`.
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
|
||||
python run_xlnet_classifier.py \
|
||||
--task_name STS-B \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--data_dir $GLUE_DIR/STS-B/ \
|
||||
--max_seq_length 128 \
|
||||
--train_batch_size 8 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mrpc_output/
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python ./examples/run_glue.py --do_train --task_name=sts-b --data_dir=${GLUE_DIR}/STS-B --output_dir=./proc_data/sts-b-110 --max_seq_length=128 --per_gpu_eval_batch_size=8 --per_gpu_train_batch_size=8 --max_steps=1200 --model_name=xlnet-large-cased --overwrite_output_dir --overwrite_cache --warmup_steps=120
|
||||
```
|
||||
|
||||
Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/zihangdai/xlnet#1-sts-b-sentence-pair-relevance-regression-with-gpus) gave evaluation results between 84% and 88%.
|
||||
This hyper-parameters give evaluation results pearsonr > 0.918.
|
||||
|
||||
### Distributed training
|
||||
|
||||
|
@ -1,528 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_transformers.modeling_bert import BertForSequenceClassification
|
||||
from pytorch_transformers.tokenization_bert import BertTokenizer
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
|
||||
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
|
||||
"bert-base-multilingual-cased, bert-base-chinese.")
|
||||
parser.add_argument("--task_name",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument("--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--cache_dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.")
|
||||
parser.add_argument("--do_train",
|
||||
action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval",
|
||||
action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_lower_case",
|
||||
action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument("--train_batch_size",
|
||||
default=32,
|
||||
type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Total batch size for eval.")
|
||||
parser.add_argument("--learning_rate",
|
||||
default=5e-5,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--warmup_proportion",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
parser.add_argument("--no_cuda",
|
||||
action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir',
|
||||
action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument('--fp16',
|
||||
action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--loss_scale',
|
||||
type=float, default=0,
|
||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||
"0 (default value): dynamic loss scaling.\n"
|
||||
"Positive power of 2: static loss scaling value.\n")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.device = device
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
|
||||
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
|
||||
device, n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
task_name = args.task_name.lower()
|
||||
|
||||
if task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name]()
|
||||
output_mode = output_modes[task_name]
|
||||
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
elif n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
global_step = 0
|
||||
nb_tr_steps = 0
|
||||
tr_loss = 0
|
||||
|
||||
if args.do_train:
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
# Prepare data loader
|
||||
train_examples = processor.get_train_examples(args.data_dir)
|
||||
cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
|
||||
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task_name)))
|
||||
try:
|
||||
with open(cached_train_features_file, "rb") as reader:
|
||||
train_features = pickle.load(reader)
|
||||
except:
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "wb") as writer:
|
||||
pickle.dump(train_features, writer)
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer
|
||||
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex.optimizers import FP16_Optimizer
|
||||
from apex.optimizers import FusedAdam
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||
|
||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
bias_correction=False,
|
||||
max_grad_norm=1.0)
|
||||
if args.loss_scale == 0:
|
||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
||||
else:
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
|
||||
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
|
||||
else:
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_examples))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||
tr_loss = 0
|
||||
nb_tr_examples, nb_tr_steps = 0, 0
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
# define a new function to compute loss values for both output_modes
|
||||
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||
loss = ouputs[0]
|
||||
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
nb_tr_steps += 1
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
# modify learning rate with special warm up BERT uses
|
||||
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr_this_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||
|
||||
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
### Example:
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
|
||||
torch.save(args, output_args_file)
|
||||
else:
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
|
||||
model.to(device)
|
||||
|
||||
### Evaluation
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
|
||||
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task_name)))
|
||||
try:
|
||||
with open(cached_eval_features_file, "rb") as reader:
|
||||
eval_features = pickle.load(reader)
|
||||
except:
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||
with open(cached_eval_features_file, "wb") as writer:
|
||||
pickle.dump(eval_features, writer)
|
||||
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
if args.local_rank == -1:
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
else:
|
||||
eval_sampler = DistributedSampler(eval_data) # Note that this sampler samples randomly
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
out_label_ids = None
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
out_label_ids = label_ids.detach().cpu().numpy()
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(
|
||||
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = preds[0]
|
||||
if output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(task_name, preds, out_label_ids)
|
||||
|
||||
loss = tr_loss/global_step if args.do_train else None
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
# hack for MNLI-MM
|
||||
if task_name == "mnli":
|
||||
task_name = "mnli-mm"
|
||||
processor = processors[task_name]()
|
||||
|
||||
if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir + '-MM'):
|
||||
os.makedirs(args.output_dir + '-MM')
|
||||
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
out_label_ids = None
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
out_label_ids = label_ids.detach().cpu().numpy()
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(
|
||||
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = preds[0]
|
||||
preds = np.argmax(preds, axis=1)
|
||||
result = compute_metrics(task_name, preds, out_label_ids)
|
||||
|
||||
loss = tr_loss/global_step if args.do_train else None
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -18,130 +18,135 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
|
||||
XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
||||
XLMTokenizer)
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForSequenceClassification, BertTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetTokenizer)
|
||||
|
||||
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_glue import (compute_metrics, convert_examples_to_features,
|
||||
output_modes, processors)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': BertForSequenceClassification,
|
||||
'xlnet': XLNetForSequenceClassification,
|
||||
'xlm': XLMForSequenceClassification,
|
||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
}
|
||||
|
||||
TOKENIZER_CLASSES = {
|
||||
'bert': BertTokenizer,
|
||||
'xlnet': XLNetTokenizer,
|
||||
'xlm': XLMTokenizer,
|
||||
}
|
||||
|
||||
def train(args, train_dataset, model):
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
num_train_optimization_steps = args.max_steps
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex.optimizers import FP16_Optimizer, FusedAdam
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0)
|
||||
if args.loss_scale == 0:
|
||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
||||
else:
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
|
||||
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
|
||||
|
||||
else:
|
||||
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", num_train_optimization_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss = 0
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
ouputs = model(**inputs)
|
||||
loss = ouputs[0]
|
||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
loss.backward() if not args.fp16 else optimizer.backward(loss)
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
# modify learning rate with special warm up BERT uses
|
||||
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr_this_step
|
||||
scheduler.step() # Update learning rate schedule
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
if args.local_rank in [-1, 0]:
|
||||
if not args.fp16:
|
||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
@ -150,59 +155,69 @@ def train(args, train_dataset, model):
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
||||
""" Evaluate the model """
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
||||
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
""" Evaluate the model """
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * args.n_gpu
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
|
||||
return result
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif args.output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
@ -214,7 +229,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
if os.path.exists(cached_features_file):
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
@ -259,6 +274,10 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
@ -270,39 +289,52 @@ def main():
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument("--train_batch_size", default=32, type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
||||
help="Total batch size for eval.")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--loss_scale', type=float, default=0,
|
||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||
"0 (default value): dynamic loss scaling.\n"
|
||||
"Positive power of 2: static loss scaling value.\n")
|
||||
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
@ -328,7 +360,9 @@ def main():
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
@ -353,22 +387,23 @@ def main():
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
args.model_type = args.model_name.lower().split('-')[0]
|
||||
tokenizer_class = TOKENIZER_CLASSES[args.model_type]
|
||||
model_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name, num_labels=num_labels)
|
||||
args.model_type = ""
|
||||
for key in MODEL_CLASSES:
|
||||
if key in args.model_name.lower():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
# Distributed, parrallel and fp16 model
|
||||
if args.fp16:
|
||||
model.half()
|
||||
# Distributed and parrallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model,
|
||||
device_ids=[args.local_rank],
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
elif args.n_gpu > 1:
|
||||
@ -377,7 +412,7 @@ def main():
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
@ -387,6 +422,7 @@ def main():
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
@ -402,17 +438,22 @@ def main():
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Handle MNLI double evaluation
|
||||
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
results = {}
|
||||
for checkpoint in checkpoints:
|
||||
global_step = int(checkpoint.split('-')[-1])
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
result = evalutate(args, eval_task, eval_output_dir, eval_dataset, model)
|
||||
|
||||
return result
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -33,36 +33,156 @@ from tqdm import tqdm, trange
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_transformers.modeling_bert import BertForQuestionAnswering
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
from pytorch_transformers.tokenization_bert import BertTokenizer
|
||||
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
|
||||
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
||||
XLMTokenizer)
|
||||
|
||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': BertForQuestionAnswering,
|
||||
'xlnet': XLNetForQuestionAnswering,
|
||||
'xlm': XLMForQuestionAnswering,
|
||||
}
|
||||
|
||||
TOKENIZER_CLASSES = {
|
||||
'bert': BertTokenizer,
|
||||
'xlnet': XLNetTokenizer,
|
||||
'xlm': XLMTokenizer,
|
||||
}
|
||||
|
||||
def train(args, train_dataset, model):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
num_train_optimization_steps = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
|
||||
t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Total batch size (distributed) = %d", args.train_batch_size * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", num_train_optimization_steps)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
ouputs = model(**inputs)
|
||||
loss = ouputs[0]
|
||||
|
||||
|
||||
def evalutate(args, dataset, model):
|
||||
""" Evaluate the model """
|
||||
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, training=True):
|
||||
""" Load data features from cache or dataset file. """
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
examples = read_squad_examples(input_file=args.train_file if training else args.predict_file,
|
||||
is_training=training,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=training)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Num orig examples = %d", len(examples))
|
||||
logger.info("Num split examples = %d", len(features))
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
if training:
|
||||
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
||||
else:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
|
||||
"bert-base-multilingual-cased, bert-base-chinese.")
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument('--version_2_with_negative', action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
|
||||
parser.add_argument("--max_seq_length", default=384, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
|
||||
"longer than this will be truncated, and sequences shorter than this will be padded.")
|
||||
@ -71,65 +191,53 @@ def main():
|
||||
parser.add_argument("--max_query_length", default=64, type=int,
|
||||
help="The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length.")
|
||||
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
|
||||
parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||
|
||||
parser.add_argument("--train_batch_size", default=32, type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--predict_batch_size", default=8, type=int,
|
||||
help="Total batch size for predictions.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
|
||||
"of training.")
|
||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||
parser.add_argument("--n_best_size", default=20, type=int,
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
||||
"output file.")
|
||||
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
|
||||
parser.add_argument("--max_answer_length", default=30, type=int,
|
||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another.")
|
||||
parser.add_argument("--verbose_logging", action='store_true',
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||
parser.add_argument("--no_cuda",
|
||||
action='store_true',
|
||||
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
default=42,
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--do_lower_case",
|
||||
action='store_true',
|
||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--fp16',
|
||||
action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--overwrite_output_dir',
|
||||
action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--loss_scale',
|
||||
type=float, default=0,
|
||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||
"0 (default value): dynamic loss scaling.\n"
|
||||
"Positive power of 2: static loss scaling value.\n")
|
||||
parser.add_argument('--version_2_with_negative',
|
||||
action='store_true',
|
||||
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||
parser.add_argument('--null_score_diff_threshold',
|
||||
type=float, default=0.0,
|
||||
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
@ -137,71 +245,52 @@ def main():
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
|
||||
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
|
||||
device, n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
# Setup logging
|
||||
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
# Setup seeds
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if n_gpu > 0:
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
if not args.do_train and not args.do_predict:
|
||||
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
|
||||
|
||||
if args.do_train:
|
||||
if not args.train_file:
|
||||
raise ValueError(
|
||||
"If `do_train` is True, then `train_file` must be specified.")
|
||||
if args.do_predict:
|
||||
if not args.predict_file:
|
||||
raise ValueError(
|
||||
"If `do_predict` is True, then `predict_file` must be specified.")
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory {} already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier() # Make sure only 1st process in distributed training download model & vocab
|
||||
|
||||
args.model_type = args.model_name.lower().split('-')[0]
|
||||
tokenizer_class = TOKENIZER_CLASSES[args.model_type]
|
||||
model_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name, num_labels=num_labels)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
# Distributed and parrallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model,
|
||||
device_ids=[args.local_rank],
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
elif n_gpu > 1:
|
||||
elif args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
@ -1,530 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_transformers.modeling_xlnet import XLNetForSequenceClassification
|
||||
from pytorch_transformers.tokenization_xlnet import XLNetTokenizer
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
|
||||
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train.")
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
# training
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0 limit the number of training steps to perform, you should choose only one of num_train_epochs and max_steps.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
parser.add_argument("--clip_gradients", default=1.0, type=float,
|
||||
help="Clip gradient norms.")
|
||||
parser.add_argument("--train_batch_size", default=32, type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--loss_scale', type=float, default=0,
|
||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||
"0 (default value): dynamic loss scaling.\n"
|
||||
"Positive power of 2: static loss scaling value.\n")
|
||||
parser.add_argument("--log_every", default=10, type=int,
|
||||
help="Log metrics every X training steps.")
|
||||
# evaluation
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the dev set.")
|
||||
parser.add_argument("--eval_batch_size", default=8, type=int,
|
||||
help="Total batch size for eval.")
|
||||
# Model
|
||||
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
|
||||
help="XLNet pre-trained model: currently only xlnet-large-cased.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
# task specific
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
# Misc
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.device = device
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
|
||||
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
|
||||
device, n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
task_name = args.task_name.lower()
|
||||
|
||||
if task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name]()
|
||||
output_mode = output_modes[task_name]
|
||||
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
tokenizer = XLNetTokenizer.from_pretrained(args.xlnet_model, do_lower_case=args.do_lower_case)
|
||||
model = XLNetForSequenceClassification.from_pretrained(args.xlnet_model, num_labels=num_labels)
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
elif n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
global_step = 0
|
||||
curr_tr_loss, curr_steps = 0., 1
|
||||
|
||||
if args.do_train:
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
# Prepare data loader
|
||||
train_examples = processor.get_train_examples(args.data_dir)
|
||||
cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
|
||||
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task_name)))
|
||||
if os.path.exists(cached_train_features_file):
|
||||
logger.info("Loading train features for cache file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "rb") as reader:
|
||||
train_features = pickle.load(reader)
|
||||
else:
|
||||
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "wb") as writer:
|
||||
pickle.dump(train_features, writer)
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = SequentialSampler(train_data) # RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
num_train_optimization_steps = args.max_steps
|
||||
else:
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer
|
||||
|
||||
optimizer_grouped_parameters = model.parameters()
|
||||
# param_optimizer = list(model.named_parameters())
|
||||
# no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
# optimizer_grouped_parameters = [
|
||||
# {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||
# {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
# ]
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex.optimizers import FP16_Optimizer
|
||||
from apex.optimizers import FusedAdam
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||
|
||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
bias_correction=False,
|
||||
max_grad_norm=1.0)
|
||||
if args.loss_scale == 0:
|
||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
||||
else:
|
||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
|
||||
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
|
||||
else:
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_examples))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs) if args.max_steps <= 0 else int('Inf'),
|
||||
desc="Epoch", disable=args.local_rank not in [-1, 0]):
|
||||
for step, batch in enumerate(tqdm(train_dataloader,
|
||||
desc="Iteration",
|
||||
disable=args.local_rank not in [-1, 0])):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
# define a new function to compute loss values for both output_modes
|
||||
loss, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
|
||||
|
||||
curr_tr_loss += loss.item()
|
||||
curr_steps += 1
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
# modify learning rate with special warm up BERT uses
|
||||
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr_this_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (global_step + 1) % args.log_every == 0):
|
||||
learning_rate = optimizer.get_lr()[0] if not args.fp16 else lr_this_step
|
||||
logger.info("[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f}".format(
|
||||
global_step, gnorm, learning_rate, curr_tr_loss / curr_steps))
|
||||
tb_writer.add_scalar('lr', learning_rate, global_step)
|
||||
tb_writer.add_scalar('loss', curr_tr_loss / curr_steps, global_step)
|
||||
curr_tr_loss, curr_steps = 0., 1
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
|
||||
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
### Example:
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = XLNetForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
|
||||
tokenizer = XLNetTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
|
||||
torch.save(args, output_args_file)
|
||||
else:
|
||||
model = XLNetForSequenceClassification.from_pretrained(args.xlnet_model, num_labels=num_labels)
|
||||
|
||||
model.to(device)
|
||||
|
||||
### Evaluation
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
|
||||
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task_name)))
|
||||
if os.path.exists(cached_eval_features_file):
|
||||
logger.info("Loading eval features for cache file %s", cached_eval_features_file)
|
||||
with open(cached_eval_features_file, "rb") as reader:
|
||||
eval_features = pickle.load(reader)
|
||||
else:
|
||||
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||
with open(cached_eval_features_file, "wb") as writer:
|
||||
pickle.dump(eval_features, writer)
|
||||
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
if args.local_rank == -1:
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
else:
|
||||
eval_sampler = DistributedSampler(eval_data) # Note that this sampler samples randomly
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
out_label_ids = None
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
||||
|
||||
# create eval loss and other metric required by the task
|
||||
if output_mode == "classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
elif output_mode == "regression":
|
||||
loss_fct = MSELoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
out_label_ids = label_ids.detach().cpu().numpy()
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(
|
||||
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = preds[0]
|
||||
if output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(task_name, preds, out_label_ids)
|
||||
|
||||
loss = curr_tr_loss/curr_steps if args.do_train else None
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
# hack for MNLI-MM
|
||||
if task_name == "mnli":
|
||||
task_name = "mnli-mm"
|
||||
processor = processors[task_name]()
|
||||
|
||||
if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir + '-MM'):
|
||||
os.makedirs(args.output_dir + '-MM')
|
||||
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
out_label_ids = None
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
out_label_ids = label_ids.detach().cpu().numpy()
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(
|
||||
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = preds[0]
|
||||
preds = np.argmax(preds, axis=1)
|
||||
result = compute_metrics(task_name, preds, out_label_ids)
|
||||
|
||||
loss = curr_tr_loss/curr_steps if args.do_train else None
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -21,6 +21,7 @@ import csv
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
|
@ -36,7 +36,7 @@ from .modeling_xlm import (XLMConfig, XLMModel,
|
||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
||||
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
||||
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
|
||||
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
|
||||
|
@ -73,17 +73,17 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
@ -93,7 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
@ -113,7 +113,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
except AttributeError:
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
@ -127,7 +127,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
@ -49,17 +49,17 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array.squeeze())
|
||||
@ -90,7 +90,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
@ -110,7 +110,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
@ -126,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
# Build TF to PyTorch weights loading map
|
||||
@ -136,7 +136,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
@ -157,7 +157,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
else:
|
||||
try:
|
||||
@ -165,13 +165,13 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
tf_weights.pop(name, None)
|
||||
tf_weights.pop(name + '/Adam', None)
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -272,7 +272,6 @@ class LogUniformSampler(object):
|
||||
self.range_max = range_max
|
||||
log_indices = torch.arange(1., range_max+2., 1.).log_()
|
||||
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
|
||||
# print('P', self.dist.numpy().tolist()[-30:])
|
||||
|
||||
self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
|
||||
|
||||
@ -331,72 +330,3 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
|
||||
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
# class LogUniformSampler(object):
|
||||
# def __init__(self, range_max, unique=False):
|
||||
# """
|
||||
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
|
||||
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
|
||||
# """
|
||||
# self.range_max = range_max
|
||||
# log_indices = torch.arange(1., range_max+2., 1.).log_()
|
||||
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
|
||||
|
||||
# self.unique = unique
|
||||
|
||||
# if self.unique:
|
||||
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
|
||||
|
||||
# def sample(self, n_sample, labels):
|
||||
# pos_sample, new_labels = labels.unique(return_inverse=True)
|
||||
# n_pos_sample = pos_sample.size(0)
|
||||
# n_neg_sample = n_sample - n_pos_sample
|
||||
|
||||
# if self.unique:
|
||||
# self.exclude_mask.index_fill_(0, pos_sample, 1)
|
||||
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
|
||||
# self.exclude_mask.index_fill_(0, pos_sample, 0)
|
||||
# else:
|
||||
# sample_dist = self.dist
|
||||
|
||||
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
|
||||
|
||||
# sample = torch.cat([pos_sample, neg_sample])
|
||||
# sample_prob = self.dist[sample]
|
||||
|
||||
# return new_labels, sample, sample_prob
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
S, B = 3, 4
|
||||
n_vocab = 10000
|
||||
n_sample = 5
|
||||
H = 32
|
||||
|
||||
labels = torch.LongTensor(S, B).random_(0, n_vocab)
|
||||
|
||||
# sampler = LogUniformSampler(n_vocab, unique=False)
|
||||
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
|
||||
|
||||
sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
|
||||
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
|
||||
|
||||
# print('true_probs', true_probs.numpy().tolist())
|
||||
# print('samp_probs', samp_probs.numpy().tolist())
|
||||
# print('neg_samples', neg_samples.numpy().tolist())
|
||||
|
||||
# print('sum', torch.sum(sampler.dist).item())
|
||||
|
||||
# assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
|
||||
|
||||
embedding = nn.Embedding(n_vocab, H)
|
||||
bias = torch.zeros(n_vocab)
|
||||
inputs = torch.Tensor(S, B, H).normal_()
|
||||
|
||||
logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample)
|
||||
print('logits', logits.detach().numpy().tolist())
|
||||
print('logits shape', logits.size())
|
||||
print('out_labels', out_labels.detach().numpy().tolist())
|
||||
print('out_labels shape', out_labels.size())
|
||||
|
||||
|
@ -57,16 +57,18 @@ class PretrainedConfig(object):
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
. `xlnet-large-cased`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `config.json` a configuration file for the model
|
||||
- a path or url to a directory containing a configuration file `config.json` for the model,
|
||||
- a path or url to a configuration file for the model.
|
||||
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||
else:
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
else:
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
@ -102,7 +104,7 @@ class PretrainedConfig(object):
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info("Model config {}".format(config))
|
||||
logger.info("Model config %s", config)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
@ -200,6 +202,7 @@ class PreTrainedModel(nn.Module):
|
||||
- a path or url to a tensorflow pretrained model checkpoint containing:
|
||||
. `config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
config: an optional configuration for the model
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
|
||||
@ -207,23 +210,31 @@ class PreTrainedModel(nn.Module):
|
||||
*inputs, **kwargs: additional input for the specific XLNet class
|
||||
(ex: num_labels for XLNetForSequenceClassification)
|
||||
"""
|
||||
config = kwargs.pop('config', None)
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
if config is None:
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||
else:
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
else:
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
|
@ -122,14 +122,14 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
@ -137,15 +137,15 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
print("Importing {}".format(name))
|
||||
logger.info("Importing {}".format(name))
|
||||
if name not in tf_weights:
|
||||
print("{} not in tf pre-trained weights, skipping".format(name))
|
||||
logger.info("{} not in tf pre-trained weights, skipping".format(name))
|
||||
continue
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if 'kernel' in name and ('ff' in name or 'summary' in name or 'logit' in name):
|
||||
print("Transposing")
|
||||
logger.info("Transposing")
|
||||
array = np.transpose(array)
|
||||
if isinstance(pointer, list):
|
||||
# Here we will split the TF weigths
|
||||
@ -157,7 +157,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
else:
|
||||
try:
|
||||
@ -165,13 +165,13 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
tf_weights.pop(name, None)
|
||||
tf_weights.pop(name + '/Adam', None)
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
return model
|
||||
|
||||
|
||||
@ -211,10 +211,6 @@ class XLNetConfig(PretrainedConfig):
|
||||
layer_norm_eps=1e-12,
|
||||
|
||||
dropout=0.1,
|
||||
dropatt=0.1,
|
||||
init="normal",
|
||||
init_range=0.1,
|
||||
init_std=0.02,
|
||||
mem_len=None,
|
||||
reuse_len=None,
|
||||
bi_data=False,
|
||||
@ -258,11 +254,6 @@ class XLNetConfig(PretrainedConfig):
|
||||
|
||||
dropout: float, dropout rate.
|
||||
dropatt: float, dropout rate on attention probabilities.
|
||||
init: str, the initialization scheme, either "normal" or "uniform".
|
||||
init_range: float, initialize the parameters with a uniform distribution
|
||||
in [-init_range, init_range]. Only effective when init="uniform".
|
||||
init_std: float, initialize the parameters with a normal distribution
|
||||
with mean 0 and stddev init_std. Only effective when init="normal".
|
||||
mem_len: int, the number of tokens to cache.
|
||||
reuse_len: int, the number of tokens in the currect batch to be cached
|
||||
and reused in the future.
|
||||
@ -297,11 +288,7 @@ class XLNetConfig(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.init = init
|
||||
self.init_range = init_range
|
||||
self.init_std = init_std
|
||||
self.dropout = dropout
|
||||
self.dropatt = dropatt
|
||||
self.mem_len = mem_len
|
||||
self.reuse_len = reuse_len
|
||||
self.bi_data = bi_data
|
||||
@ -393,7 +380,7 @@ class XLNetRelativeAttention(nn.Module):
|
||||
x = x[1:, ...]
|
||||
x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
|
||||
# x = x[:, 0:klen, :, :]
|
||||
x = torch.index_select(x, 1, torch.arange(klen))
|
||||
x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
|
||||
|
||||
return x
|
||||
|
||||
|
@ -14,174 +14,92 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch optimization for BERT model."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import required
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
import logging
|
||||
import abc
|
||||
import sys
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConstantLRSchedule(LambdaLR):
|
||||
def __init__(self, optimizer, last_epoch=-1):
|
||||
super(ConstantLRSchedule, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)
|
||||
|
||||
if sys.version_info >= (3, 4):
|
||||
ABC = abc.ABC
|
||||
else:
|
||||
ABC = abc.ABCMeta('ABC', (), {})
|
||||
|
||||
|
||||
class _LRSchedule(ABC):
|
||||
""" Parent of all LRSchedules here. """
|
||||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
|
||||
def __init__(self, warmup=0.002, t_total=-1, **kw):
|
||||
"""
|
||||
:param warmup: what fraction of t_total steps will be used for linear warmup
|
||||
:param t_total: how many training steps (updates) are planned
|
||||
:param kw:
|
||||
"""
|
||||
super(_LRSchedule, self).__init__(**kw)
|
||||
if t_total < 0:
|
||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||
warmup = max(warmup, 0.)
|
||||
self.warmup, self.t_total = float(warmup), float(t_total)
|
||||
self.warned_for_t_total_at_progress = -1
|
||||
|
||||
def get_lr(self, step, nowarn=False):
|
||||
"""
|
||||
:param step: which of t_total steps we're on
|
||||
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
|
||||
:return: learning rate multiplier for current update
|
||||
"""
|
||||
if self.t_total < 0:
|
||||
return 1.
|
||||
progress = float(step) / self.t_total
|
||||
ret = self.get_lr_(progress)
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
|
||||
logger.warning(
|
||||
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
|
||||
.format(ret, self.__class__.__name__))
|
||||
self.warned_for_t_total_at_progress = progress
|
||||
# end warning
|
||||
return ret
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_lr_(self, progress):
|
||||
"""
|
||||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
|
||||
:return: learning rate multiplier for current update
|
||||
"""
|
||||
return 1.
|
||||
|
||||
|
||||
class ConstantLR(_LRSchedule):
|
||||
def get_lr_(self, progress):
|
||||
return 1.
|
||||
|
||||
|
||||
class WarmupCosineSchedule(_LRSchedule):
|
||||
class WarmupCosineSchedule(LambdaLR):
|
||||
"""
|
||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||
Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
|
||||
Linearly increases learning rate from 0 to 1 over `warmup` training steps.
|
||||
Decreases learning rate from 1. to 0. over remaining `t_total - warmup` steps following a cosine curve.
|
||||
If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
|
||||
:param warmup: see LRSchedule
|
||||
:param t_total: see LRSchedule
|
||||
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
|
||||
:param kw:
|
||||
"""
|
||||
warn_t_total = True
|
||||
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
|
||||
"""
|
||||
:param warmup: see LRSchedule
|
||||
:param t_total: see LRSchedule
|
||||
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
|
||||
:param kw:
|
||||
"""
|
||||
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
|
||||
self.cycles = cycles
|
||||
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
|
||||
|
||||
def get_lr_(self, progress):
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
else:
|
||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
|
||||
def lr_lambda(step):
|
||||
if step < warmup_steps:
|
||||
return step / max(1, warmup_steps)
|
||||
else:
|
||||
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
|
||||
return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
|
||||
|
||||
super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
|
||||
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
|
||||
"""
|
||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
|
||||
learning rate (with hard restarts).
|
||||
"""
|
||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||
assert(cycles >= 1.)
|
||||
def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
|
||||
|
||||
def get_lr_(self, progress):
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
else:
|
||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||
ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
|
||||
return ret
|
||||
def lr_lambda(step):
|
||||
if step < warmup_steps:
|
||||
return step / max(1, warmup_steps)
|
||||
else:
|
||||
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
|
||||
ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
|
||||
return ret
|
||||
|
||||
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
|
||||
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
|
||||
"""
|
||||
All training progress is divided in `cycles` (default=1.) parts of equal length.
|
||||
Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
|
||||
followed by a learning rate decreasing from 1. to 0. following a cosine curve.
|
||||
"""
|
||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||
assert(warmup * cycles < 1.)
|
||||
warmup = warmup * cycles if warmup >= 0 else warmup
|
||||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||
|
||||
def get_lr_(self, progress):
|
||||
progress = progress * self.cycles % 1.
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
else:
|
||||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
|
||||
ret = 0.5 * (1. + math.cos(math.pi * progress))
|
||||
return ret
|
||||
|
||||
|
||||
class WarmupConstantSchedule(_LRSchedule):
|
||||
class WarmupConstantSchedule(LambdaLR):
|
||||
"""
|
||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||
Keeps learning rate equal to 1. after warmup.
|
||||
"""
|
||||
def get_lr_(self, progress):
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
return 1.
|
||||
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
|
||||
|
||||
def lr_lambda(step):
|
||||
if step < warmup_steps:
|
||||
return step / warmup_steps
|
||||
return 1.
|
||||
|
||||
super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
|
||||
class WarmupLinearSchedule(_LRSchedule):
|
||||
class WarmupLinearSchedule(LambdaLR):
|
||||
"""
|
||||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
|
||||
Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
|
||||
"""
|
||||
warn_t_total = True
|
||||
def get_lr_(self, progress):
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
return max((progress - 1.) / (self.warmup - 1.), 0.)
|
||||
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
|
||||
|
||||
def lr_lambda(step):
|
||||
if step < warmup_steps:
|
||||
return step / max(1, warmup_steps)
|
||||
return (t_total - step) / max(1, t_total - warmup_steps)
|
||||
|
||||
super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
|
||||
SCHEDULES = {
|
||||
None: ConstantLR,
|
||||
"none": ConstantLR,
|
||||
"warmup_cosine": WarmupCosineSchedule,
|
||||
"warmup_constant": WarmupConstantSchedule,
|
||||
"warmup_linear": WarmupLinearSchedule
|
||||
}
|
||||
|
||||
|
||||
class BertAdam(Optimizer):
|
||||
"""Implements BERT version of Adam algorithm with weight decay fix.
|
||||
class AdamW(Optimizer):
|
||||
""" Implements Adam algorithm with weight decay fix.
|
||||
|
||||
Parameters:
|
||||
lr: learning rate
|
||||
@ -197,43 +115,20 @@ class BertAdam(Optimizer):
|
||||
e: Adams epsilon. Default: 1e-6
|
||||
weight_decay: Weight decay. Default: 0.01
|
||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||
correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in Bert repository)
|
||||
"""
|
||||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
||||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
|
||||
if lr is not required and lr < 0.0:
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
|
||||
if lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||
if not 0.0 <= b1 < 1.0:
|
||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||
if not 0.0 <= b2 < 1.0:
|
||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
||||
if not e >= 0.0:
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1] ))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||
# initialize schedule object
|
||||
if not isinstance(schedule, _LRSchedule):
|
||||
schedule_type = SCHEDULES[schedule]
|
||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
||||
else:
|
||||
if warmup != -1 or t_total != -1:
|
||||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
|
||||
"Please specify custom warmup and t_total in _LRSchedule object.")
|
||||
defaults = dict(lr=lr, schedule=schedule,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(BertAdam, self).__init__(params, defaults)
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
return [0]
|
||||
lr_scheduled = group['lr']
|
||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||
lr.append(lr_scheduled)
|
||||
return lr
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
correct_bias=correct_bias)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
@ -260,22 +155,28 @@ class BertAdam(Optimizer):
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['next_m'] = torch.zeros_like(p.data)
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['next_v'] = torch.zeros_like(p.data)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
next_m, next_v = state['next_m'], state['next_v']
|
||||
beta1, beta2 = group['b1'], group['b2']
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
# Add grad clipping
|
||||
if group['max_grad_norm'] > 0:
|
||||
clip_grad_norm_(p, group['max_grad_norm'])
|
||||
state['step'] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
next_m.mul_(beta1).add_(1 - beta1, grad)
|
||||
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
update = next_m / (next_v.sqrt() + group['e'])
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
step_size = group['lr']
|
||||
if group['correct_bias']: # No bias correction for Bert
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
@ -284,20 +185,8 @@ class BertAdam(Optimizer):
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
if group['weight_decay'] > 0.0:
|
||||
update += group['weight_decay'] * p.data
|
||||
|
||||
lr_scheduled = group['lr']
|
||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||
|
||||
update_with_lr = lr_scheduled * update
|
||||
p.data.add_(-update_with_lr)
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
||||
# No bias correction
|
||||
# bias_correction1 = 1 - beta1 ** state['step']
|
||||
# bias_correction2 = 1 - beta2 ** state['step']
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group['weight_decay'] > 0:
|
||||
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
|
||||
|
||||
return loss
|
||||
|
@ -1,127 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch optimization for OpenAI GPT model."""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.optimizer import required
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
import logging
|
||||
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
|
||||
WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAdam(Optimizer):
|
||||
"""Implements Open AI version of Adam algorithm with weight decay fix.
|
||||
"""
|
||||
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
|
||||
b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
|
||||
vector_l2=False, max_grad_norm=-1, **kwargs):
|
||||
if lr is not required and lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
|
||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||
if not 0.0 <= b1 < 1.0:
|
||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||
if not 0.0 <= b2 < 1.0:
|
||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
||||
if not e >= 0.0:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||
# initialize schedule object
|
||||
if not isinstance(schedule, _LRSchedule):
|
||||
schedule_type = SCHEDULES[schedule]
|
||||
schedule = schedule_type(warmup=warmup, t_total=t_total)
|
||||
else:
|
||||
if warmup != -1 or t_total != -1:
|
||||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
|
||||
"Please specify custom warmup and t_total in _LRSchedule object.")
|
||||
defaults = dict(lr=lr, schedule=schedule,
|
||||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
|
||||
max_grad_norm=max_grad_norm)
|
||||
super(OpenAIAdam, self).__init__(params, defaults)
|
||||
|
||||
def get_lr(self):
|
||||
lr = []
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
return [0]
|
||||
lr_scheduled = group['lr']
|
||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||
lr.append(lr_scheduled)
|
||||
return lr
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['b1'], group['b2']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Add grad clipping
|
||||
if group['max_grad_norm'] > 0:
|
||||
clip_grad_norm_(p, group['max_grad_norm'])
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
denom = exp_avg_sq.sqrt().add_(group['e'])
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
lr_scheduled = group['lr']
|
||||
lr_scheduled *= group['schedule'].get_lr(state['step'])
|
||||
|
||||
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
# Add weight decay at the end (fixed version)
|
||||
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
|
||||
p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
|
||||
|
||||
return loss
|
@ -20,10 +20,9 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import BertAdam
|
||||
from pytorch_transformers import OpenAIAdam
|
||||
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
|
||||
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -34,12 +33,12 @@ class OptimizationTest(unittest.TestCase):
|
||||
for a, b in zip(list1, list2):
|
||||
self.assertAlmostEqual(a, b, delta=tol)
|
||||
|
||||
def test_adam(self):
|
||||
def test_adam_w(self):
|
||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = torch.nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||
optimizer = AdamW(params=[w], lr=2e-1,
|
||||
weight_decay=0.0,
|
||||
max_grad_norm=-1)
|
||||
for _ in range(100):
|
||||
@ -52,23 +51,13 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
def test_bert_sched_init(self):
|
||||
def test_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
def test_openai_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
|
@ -98,14 +98,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
self.build_vocab()
|
||||
|
||||
def count_file(self, path, verbose=False, add_eos=False):
|
||||
if verbose: print('counting file {} ...'.format(path))
|
||||
if verbose: logger.info('counting file {} ...'.format(path))
|
||||
assert os.path.exists(path)
|
||||
|
||||
sents = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for idx, line in enumerate(f):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
logger.info(' line {}'.format(idx))
|
||||
symbols = self.tokenize(line, add_eos=add_eos)
|
||||
self.counter.update(symbols)
|
||||
sents.append(symbols)
|
||||
@ -116,10 +116,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
sents : a list of sentences, each a list of tokenized symbols
|
||||
"""
|
||||
if verbose: print('counting {} sents ...'.format(len(sents)))
|
||||
if verbose: logger.info('counting {} sents ...'.format(len(sents)))
|
||||
for idx, symbols in enumerate(sents):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
logger.info(' line {}'.format(idx))
|
||||
self.counter.update(symbols)
|
||||
|
||||
def _build_from_file(self, vocab_file):
|
||||
@ -147,11 +147,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
print('building vocab from {}'.format(self.vocab_file))
|
||||
logger.info('building vocab from {}'.format(self.vocab_file))
|
||||
self._build_from_file(self.vocab_file)
|
||||
print('final vocab size {}'.format(len(self)))
|
||||
logger.info('final vocab size {}'.format(len(self)))
|
||||
else:
|
||||
print('building vocab with min_freq={}, max_size={}'.format(
|
||||
logger.info('building vocab with min_freq={}, max_size={}'.format(
|
||||
self.min_freq, self.max_size))
|
||||
self.idx2sym = []
|
||||
self.sym2idx = OrderedDict()
|
||||
@ -163,18 +163,18 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
if cnt < self.min_freq: break
|
||||
self.add_symbol(sym)
|
||||
|
||||
print('final vocab size {} from {} unique tokens'.format(
|
||||
logger.info('final vocab size {} from {} unique tokens'.format(
|
||||
len(self), len(self.counter)))
|
||||
|
||||
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
|
||||
add_double_eos=False):
|
||||
if verbose: print('encoding file {} ...'.format(path))
|
||||
if verbose: logger.info('encoding file {} ...'.format(path))
|
||||
assert os.path.exists(path)
|
||||
encoded = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for idx, line in enumerate(f):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
logger.info(' line {}'.format(idx))
|
||||
symbols = self.tokenize(line, add_eos=add_eos,
|
||||
add_double_eos=add_double_eos)
|
||||
encoded.append(self.convert_to_tensor(symbols))
|
||||
@ -185,11 +185,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
return encoded
|
||||
|
||||
def encode_sents(self, sents, ordered=False, verbose=False):
|
||||
if verbose: print('encoding {} sents ...'.format(len(sents)))
|
||||
if verbose: logger.info('encoding {} sents ...'.format(len(sents)))
|
||||
encoded = []
|
||||
for idx, symbols in enumerate(sents):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
logger.info(' line {}'.format(idx))
|
||||
encoded.append(self.convert_to_tensor(symbols))
|
||||
|
||||
if ordered:
|
||||
@ -218,7 +218,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
if sym in self.sym2idx:
|
||||
return self.sym2idx[sym]
|
||||
else:
|
||||
# print('encounter unk {}'.format(sym))
|
||||
# logger.info('encounter unk {}'.format(sym))
|
||||
# assert '<eos>' not in sym
|
||||
if hasattr(self, 'unk_idx'):
|
||||
return self.sym2idx.get(sym, self.unk_idx)
|
||||
@ -544,14 +544,14 @@ def get_lm_corpus(datadir, dataset):
|
||||
fn = os.path.join(datadir, 'cache.pt')
|
||||
fn_pickle = os.path.join(datadir, 'cache.pkl')
|
||||
if os.path.exists(fn):
|
||||
print('Loading cached dataset...')
|
||||
logger.info('Loading cached dataset...')
|
||||
corpus = torch.load(fn_pickle)
|
||||
elif os.path.exists(fn):
|
||||
print('Loading cached dataset from pickle...')
|
||||
logger.info('Loading cached dataset from pickle...')
|
||||
with open(fn, "rb") as fp:
|
||||
corpus = pickle.load(fp)
|
||||
else:
|
||||
print('Producing dataset {}...'.format(dataset))
|
||||
logger.info('Producing dataset {}...'.format(dataset))
|
||||
kwargs = {}
|
||||
if dataset in ['wt103', 'wt2']:
|
||||
kwargs['special'] = ['<eos>']
|
||||
|
Loading…
Reference in New Issue
Block a user