Merge pull request #777 from huggingface/examples

Working GLUE Example for XLNet (STS-B)
This commit is contained in:
Thomas Wolf 2019-07-11 15:43:47 +02:00 committed by GitHub
commit d216e798af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 529 additions and 1787 deletions

View File

@ -1620,20 +1620,10 @@ and unpack it to some directory `$GLUE_DIR`.
```shell
export GLUE_DIR=/path/to/glue
python run_xlnet_classifier.py \
--task_name STS-B \
--do_train \
--do_eval \
--data_dir $GLUE_DIR/STS-B/ \
--max_seq_length 128 \
--train_batch_size 8 \
--gradient_accumulation_steps 1 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
CUDA_VISIBLE_DEVICES=0,1,2,3 python ./examples/run_glue.py --do_train --task_name=sts-b --data_dir=${GLUE_DIR}/STS-B --output_dir=./proc_data/sts-b-110 --max_seq_length=128 --per_gpu_eval_batch_size=8 --per_gpu_train_batch_size=8 --max_steps=1200 --model_name=xlnet-large-cased --overwrite_output_dir --overwrite_cache --warmup_steps=120
```
Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/zihangdai/xlnet#1-sts-b-sentence-pair-relevance-regression-with-gpus) gave evaluation results between 84% and 88%.
This hyper-parameters give evaluation results pearsonr > 0.918.
### Distributed training

View File

@ -1,528 +0,0 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import, division, print_function
import argparse
import logging
import os
import sys
import random
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss
from tensorboardX import SummaryWriter
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForSequenceClassification
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--do_train",
action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval",
action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=8,
type=int,
help="Total batch size for eval.")
parser.add_argument("--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--overwrite_output_dir',
action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
args.device = device
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
task_name = args.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
global_step = 0
nb_tr_steps = 0
tr_loss = 0
if args.do_train:
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
# Prepare data loader
train_examples = processor.get_train_examples(args.data_dir)
cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
list(filter(None, args.bert_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
try:
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
except:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
pickle.dump(train_features, writer)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
loss = ouputs[0]
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
if args.local_rank in [-1, 0]:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step)
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
### Example:
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
# Good practice: save your training arguments together with the trained model
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
model.to(device)
### Evaluation
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
list(filter(None, args.bert_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
try:
with open(cached_eval_features_file, "rb") as reader:
eval_features = pickle.load(reader)
except:
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
with open(cached_eval_features_file, "wb") as writer:
pickle.dump(eval_features, writer)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data) # Note that this sampler samples randomly
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
out_label_ids = None
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
out_label_ids = label_ids.detach().cpu().numpy()
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
if output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, out_label_ids)
loss = tr_loss/global_step if args.do_train else None
result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
# hack for MNLI-MM
if task_name == "mnli":
task_name = "mnli-mm"
processor = processors[task_name]()
if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir + '-MM'):
os.makedirs(args.output_dir + '-MM')
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
out_label_ids = None
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
out_label_ids = label_ids.detach().cpu().numpy()
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, out_label_ids)
loss = tr_loss/global_step if args.do_train else None
result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__":
main()

View File

@ -18,130 +18,135 @@
from __future__ import absolute_import, division, print_function
import argparse
import glob
import logging
import os
import random
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
XLMTokenizer)
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForSequenceClassification, BertTokenizer,
XLMConfig, XLMForSequenceClassification,
XLMTokenizer, XLNetConfig,
XLNetForSequenceClassification,
XLNetTokenizer)
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
from pytorch_transformers import AdamW, WarmupLinearSchedule
from utils_glue import (compute_metrics, convert_examples_to_features,
output_modes, processors)
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': BertForSequenceClassification,
'xlnet': XLNetForSequenceClassification,
'xlm': XLMForSequenceClassification,
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
}
TOKENIZER_CLASSES = {
'bert': BertTokenizer,
'xlnet': XLNetTokenizer,
'xlm': XLMTokenizer,
}
def train(args, train_dataset, model):
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
num_train_optimization_steps = args.max_steps
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer, FusedAdam
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", num_train_optimization_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss = 0
model.train()
optimizer.zero_grad()
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]}
ouputs = model(**inputs)
loss = ouputs[0]
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward() if not args.fp16 else optimizer.backward(loss)
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
scheduler.step() # Update learning rate schedule
optimizer.step()
optimizer.zero_grad()
model.zero_grad()
global_step += 1
if args.local_rank in [-1, 0]:
if not args.fp16:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
if args.local_rank == -1: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
break
if args.max_steps > 0 and global_step > args.max_steps:
@ -150,59 +155,69 @@ def train(args, train_dataset, model):
return global_step, tr_loss / global_step
def evalutate(args, eval_task, eval_output_dir, dataset, model):
""" Evaluate the model """
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
def evaluate(args, model, tokenizer, prefix=""):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
results = {}
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
# Eval!
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = None
out_label_ids = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
batch = tuple(t.to(args.device) for t in batch)
""" Evaluate the model """
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
args.eval_batch_size = args.per_gpu_eval_batch_size * args.n_gpu
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = inputs['labels'].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0
nb_eval_steps = 0
preds = None
out_label_ids = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif args.output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(eval_task, preds, out_label_ids)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = inputs['labels'].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
return result
eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif args.output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return results
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
@ -214,7 +229,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
list(filter(None, args.model_name.split('/'))).pop(),
str(args.max_seq_length),
str(task)))
if os.path.exists(cached_features_file):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
@ -259,6 +274,10 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length", default=128, type=int,
@ -270,39 +289,52 @@ def main():
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size", default=32, type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=8, type=int,
help="Total batch size for eval.")
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU for evaluation.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.")
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=50,
help="Save checkpoint every X updates steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--overwrite_cache', action='store_true',
help="Overwrite the cached training and evaluation sets")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
help="For distributed training: local_rank")
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
@ -328,7 +360,9 @@ def main():
args.device = device
# Setup logging
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
@ -353,22 +387,23 @@ def main():
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
args.model_type = args.model_name.lower().split('-')[0]
tokenizer_class = TOKENIZER_CLASSES[args.model_type]
model_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name, num_labels=num_labels)
args.model_type = ""
for key in MODEL_CLASSES:
if key in args.model_name.lower():
args.model_type = key # take the first match in model types
break
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
if args.local_rank == 0:
torch.distributed.barrier()
# Distributed, parrallel and fp16 model
if args.fp16:
model.half()
# Distributed and parrallel training
model.to(args.device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
elif args.n_gpu > 1:
@ -377,7 +412,7 @@ def main():
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
@ -387,6 +422,7 @@ def main():
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
@ -402,17 +438,22 @@ def main():
model.to(args.device)
# Evaluation
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Handle MNLI double evaluation
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
results = {}
for checkpoint in checkpoints:
global_step = int(checkpoint.split('-')[-1])
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result)
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
result = evalutate(args, eval_task, eval_output_dir, eval_dataset, model)
return result
return results
if __name__ == "__main__":

View File

@ -33,36 +33,156 @@ from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForQuestionAnswering
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
XLMTokenizer)
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
MODEL_CLASSES = {
'bert': BertForQuestionAnswering,
'xlnet': XLNetForQuestionAnswering,
'xlm': XLMForQuestionAnswering,
}
TOKENIZER_CLASSES = {
'bert': BertTokenizer,
'xlnet': XLNetTokenizer,
'xlm': XLMTokenizer,
}
def train(args, train_dataset, model):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
num_train_optimization_steps = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Total batch size (distributed) = %d", args.train_batch_size * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", num_train_optimization_steps)
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.train()
optimizer.zero_grad()
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]}
ouputs = model(**inputs)
loss = ouputs[0]
def evalutate(args, dataset, model):
""" Evaluate the model """
def load_and_cache_examples(args, tokenizer, training=True):
""" Load data features from cache or dataset file. """
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
'dev' if evaluate else 'train',
list(filter(None, args.model_name.split('/'))).pop(),
str(args.max_seq_length),
str(task)))
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
examples = read_squad_examples(input_file=args.train_file if training else args.predict_file,
is_training=training,
version_2_with_negative=args.version_2_with_negative)
features = convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=training)
if args.local_rank in [-1, 0]:
logger.info("Num orig examples = %d", len(examples))
logger.info("Num split examples = %d", len(features))
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
if training:
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
else:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
return dataset
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--train_file", default=None, type=str, required=True,
help="SQuAD json for training. E.g., train-v1.1.json")
parser.add_argument("--predict_file", default=None, type=str, required=True,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument("--model_name", default=None, type=str, required=True,
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.")
## Other parameters
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
parser.add_argument("--predict_file", default=None, type=str,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument('--version_2_with_negative', action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
@ -71,65 +191,53 @@ def main():
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_predict", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case", action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--train_batch_size", default=32, type=int,
help="Total batch size for training.")
parser.add_argument("--predict_batch_size", default=8, type=int,
help="Total batch size for predictions.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
"of training.")
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
parser.add_argument("--n_best_size", default=20, type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json "
"output file.")
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.")
parser.add_argument("--verbose_logging", action='store_true',
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
parser.add_argument("--no_cuda",
action='store_true',
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--seed',
type=int,
default=42,
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank",
type=int,
default=-1,
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--overwrite_output_dir',
action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--loss_scale',
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument('--null_score_diff_threshold',
type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
print(args)
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
@ -137,71 +245,52 @@ def main():
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
args.n_gpu = 1
args.device = device
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
# Setup logging
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Setup seeds
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_predict:
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
if args.do_train:
if not args.train_file:
raise ValueError(
"If `do_train` is True, then `train_file` must be specified.")
if args.do_predict:
if not args.predict_file:
raise ValueError(
"If `do_predict` is True, then `predict_file` must be specified.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory {} already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier() # Make sure only 1st process in distributed training download model & vocab
args.model_type = args.model_name.lower().split('-')[0]
tokenizer_class = TOKENIZER_CLASSES[args.model_type]
model_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16:
model.half()
model.to(device)
# Distributed and parrallel training
model.to(args.device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
elif n_gpu > 1:
elif args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Training
if args.do_train:
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()

View File

@ -1,530 +0,0 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import, division, print_function
import argparse
import logging
import os
import sys
import random
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss
from tensorboardX import SummaryWriter
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_xlnet import XLNetForSequenceClassification
from pytorch_transformers.tokenization_xlnet import XLNetTokenizer
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--task_name", default=None, type=str, required=True,
help="The name of the task to train.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
# training
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0 limit the number of training steps to perform, you should choose only one of num_train_epochs and max_steps.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--clip_gradients", default=1.0, type=float,
help="Clip gradient norms.")
parser.add_argument("--train_batch_size", default=32, type=int,
help="Total batch size for training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument("--log_every", default=10, type=int,
help="Log metrics every X training steps.")
# evaluation
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--eval_batch_size", default=8, type=int,
help="Total batch size for eval.")
# Model
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
help="XLNet pre-trained model: currently only xlnet-large-cased.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
# task specific
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
# Misc
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
args.device = device
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
task_name = args.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = XLNetTokenizer.from_pretrained(args.xlnet_model, do_lower_case=args.do_lower_case)
model = XLNetForSequenceClassification.from_pretrained(args.xlnet_model, num_labels=num_labels)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
global_step = 0
curr_tr_loss, curr_steps = 0., 1
if args.do_train:
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
# Prepare data loader
train_examples = processor.get_train_examples(args.data_dir)
cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
if os.path.exists(cached_train_features_file):
logger.info("Loading train features for cache file %s", cached_train_features_file)
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
else:
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
pickle.dump(train_features, writer)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = SequentialSampler(train_data) # RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
num_train_optimization_steps = args.max_steps
else:
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer
optimizer_grouped_parameters = model.parameters()
# param_optimizer = list(model.named_parameters())
# no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# optimizer_grouped_parameters = [
# {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
# {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
model.train()
for _ in trange(int(args.num_train_epochs) if args.max_steps <= 0 else int('Inf'),
desc="Epoch", disable=args.local_rank not in [-1, 0]):
for step, batch in enumerate(tqdm(train_dataloader,
desc="Iteration",
disable=args.local_rank not in [-1, 0])):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
loss, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
curr_tr_loss += loss.item()
curr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
if args.local_rank in [-1, 0] and (args.log_every <= 0 or (global_step + 1) % args.log_every == 0):
learning_rate = optimizer.get_lr()[0] if not args.fp16 else lr_this_step
logger.info("[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f}".format(
global_step, gnorm, learning_rate, curr_tr_loss / curr_steps))
tb_writer.add_scalar('lr', learning_rate, global_step)
tb_writer.add_scalar('loss', curr_tr_loss / curr_steps, global_step)
curr_tr_loss, curr_steps = 0., 1
if args.max_steps > 0 and global_step > args.max_steps:
break
if args.max_steps > 0 and global_step > args.max_steps:
break
### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
### Example:
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
# Load a trained model and vocabulary that you have fine-tuned
model = XLNetForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
tokenizer = XLNetTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
# Good practice: save your training arguments together with the trained model
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = XLNetForSequenceClassification.from_pretrained(args.xlnet_model, num_labels=num_labels)
model.to(device)
### Evaluation
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
list(filter(None, args.xlnet_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
if os.path.exists(cached_eval_features_file):
logger.info("Loading eval features for cache file %s", cached_eval_features_file)
with open(cached_eval_features_file, "rb") as reader:
eval_features = pickle.load(reader)
else:
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=True, cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
with open(cached_eval_features_file, "wb") as writer:
pickle.dump(eval_features, writer)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data) # Note that this sampler samples randomly
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
out_label_ids = None
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
out_label_ids = label_ids.detach().cpu().numpy()
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
if output_mode == "classification":
preds = np.argmax(preds, axis=1)
elif output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, out_label_ids)
loss = curr_tr_loss/curr_steps if args.do_train else None
result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
# hack for MNLI-MM
if task_name == "mnli":
task_name = "mnli-mm"
processor = processors[task_name]()
if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir + '-MM'):
os.makedirs(args.output_dir + '-MM')
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
out_label_ids = None
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
out_label_ids = label_ids.detach().cpu().numpy()
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(
out_label_ids, label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, out_label_ids)
loss = curr_tr_loss/curr_steps if args.do_train else None
result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__":
main()

View File

@ -21,6 +21,7 @@ import csv
import logging
import os
import sys
from io import open
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score

View File

@ -36,7 +36,7 @@ from .modeling_xlm import (XLMConfig, XLMModel,
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)

View File

@ -73,17 +73,17 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
@ -93,7 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
@ -113,7 +113,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
try:
pointer = getattr(pointer, l[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
logger.info("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
@ -127,7 +127,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model

View File

@ -49,17 +49,17 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(gpt2_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array.squeeze())
@ -90,7 +90,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model

View File

@ -110,7 +110,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model

View File

@ -126,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
# Build TF to PyTorch weights loading map
@ -136,7 +136,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
@ -157,7 +157,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
else:
try:
@ -165,13 +165,13 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
return model

View File

@ -272,7 +272,6 @@ class LogUniformSampler(object):
self.range_max = range_max
log_indices = torch.arange(1., range_max+2., 1.).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# print('P', self.dist.numpy().tolist()[-30:])
self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
@ -331,72 +330,3 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
return logits
# class LogUniformSampler(object):
# def __init__(self, range_max, unique=False):
# """
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
# """
# self.range_max = range_max
# log_indices = torch.arange(1., range_max+2., 1.).log_()
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# self.unique = unique
# if self.unique:
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
# def sample(self, n_sample, labels):
# pos_sample, new_labels = labels.unique(return_inverse=True)
# n_pos_sample = pos_sample.size(0)
# n_neg_sample = n_sample - n_pos_sample
# if self.unique:
# self.exclude_mask.index_fill_(0, pos_sample, 1)
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
# self.exclude_mask.index_fill_(0, pos_sample, 0)
# else:
# sample_dist = self.dist
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
# sample = torch.cat([pos_sample, neg_sample])
# sample_prob = self.dist[sample]
# return new_labels, sample, sample_prob
if __name__ == '__main__':
S, B = 3, 4
n_vocab = 10000
n_sample = 5
H = 32
labels = torch.LongTensor(S, B).random_(0, n_vocab)
# sampler = LogUniformSampler(n_vocab, unique=False)
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
# print('true_probs', true_probs.numpy().tolist())
# print('samp_probs', samp_probs.numpy().tolist())
# print('neg_samples', neg_samples.numpy().tolist())
# print('sum', torch.sum(sampler.dist).item())
# assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
embedding = nn.Embedding(n_vocab, H)
bias = torch.zeros(n_vocab)
inputs = torch.Tensor(S, B, H).normal_()
logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample)
print('logits', logits.detach().numpy().tolist())
print('logits shape', logits.size())
print('out_labels', out_labels.detach().numpy().tolist())
print('out_labels shape', out_labels.size())

View File

@ -57,16 +57,18 @@ class PretrainedConfig(object):
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
- a path or url to a directory containing a configuration file `config.json` for the model,
- a path or url to a configuration file for the model.
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
"""
cache_dir = kwargs.pop('cache_dir', None)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
else:
elif os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
@ -102,7 +104,7 @@ class PretrainedConfig(object):
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
logger.info("Model config %s", config)
return config
@classmethod
@ -200,6 +202,7 @@ class PreTrainedModel(nn.Module):
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
config: an optional configuration for the model
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
@ -207,23 +210,31 @@ class PreTrainedModel(nn.Module):
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False)
output_loading_info = kwargs.pop('output_loading_info', False)
# Load config
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if config is None:
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
else:
elif os.path.isdir(pretrained_model_name_or_path):
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)

View File

@ -122,14 +122,14 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
@ -137,15 +137,15 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
for name, pointer in tf_to_pt_map.items():
print("Importing {}".format(name))
logger.info("Importing {}".format(name))
if name not in tf_weights:
print("{} not in tf pre-trained weights, skipping".format(name))
logger.info("{} not in tf pre-trained weights, skipping".format(name))
continue
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if 'kernel' in name and ('ff' in name or 'summary' in name or 'logit' in name):
print("Transposing")
logger.info("Transposing")
array = np.transpose(array)
if isinstance(pointer, list):
# Here we will split the TF weigths
@ -157,7 +157,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
else:
try:
@ -165,13 +165,13 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
return model
@ -211,10 +211,6 @@ class XLNetConfig(PretrainedConfig):
layer_norm_eps=1e-12,
dropout=0.1,
dropatt=0.1,
init="normal",
init_range=0.1,
init_std=0.02,
mem_len=None,
reuse_len=None,
bi_data=False,
@ -258,11 +254,6 @@ class XLNetConfig(PretrainedConfig):
dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
@ -297,11 +288,7 @@ class XLNetConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.init = init
self.init_range = init_range
self.init_std = init_std
self.dropout = dropout
self.dropatt = dropatt
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
@ -393,7 +380,7 @@ class XLNetRelativeAttention(nn.Module):
x = x[1:, ...]
x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
# x = x[:, 0:klen, :, :]
x = torch.index_select(x, 1, torch.arange(klen))
x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
return x

View File

@ -14,174 +14,92 @@
# limitations under the License.
"""PyTorch optimization for BERT model."""
import logging
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
import abc
import sys
from torch.optim.lr_scheduler import LambdaLR
logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
def __init__(self, optimizer, last_epoch=-1):
super(ConstantLRSchedule, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)
if sys.version_info >= (3, 4):
ABC = abc.ABC
else:
ABC = abc.ABCMeta('ABC', (), {})
class _LRSchedule(ABC):
""" Parent of all LRSchedules here. """
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
def __init__(self, warmup=0.002, t_total=-1, **kw):
"""
:param warmup: what fraction of t_total steps will be used for linear warmup
:param t_total: how many training steps (updates) are planned
:param kw:
"""
super(_LRSchedule, self).__init__(**kw)
if t_total < 0:
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
warmup = max(warmup, 0.)
self.warmup, self.t_total = float(warmup), float(t_total)
self.warned_for_t_total_at_progress = -1
def get_lr(self, step, nowarn=False):
"""
:param step: which of t_total steps we're on
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
:return: learning rate multiplier for current update
"""
if self.t_total < 0:
return 1.
progress = float(step) / self.t_total
ret = self.get_lr_(progress)
# warning for exceeding t_total (only active with warmup_linear
if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
logger.warning(
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
.format(ret, self.__class__.__name__))
self.warned_for_t_total_at_progress = progress
# end warning
return ret
@abc.abstractmethod
def get_lr_(self, progress):
"""
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
:return: learning rate multiplier for current update
"""
return 1.
class ConstantLR(_LRSchedule):
def get_lr_(self, progress):
return 1.
class WarmupCosineSchedule(_LRSchedule):
class WarmupCosineSchedule(LambdaLR):
"""
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
Linearly increases learning rate from 0 to 1 over `warmup` training steps.
Decreases learning rate from 1. to 0. over remaining `t_total - warmup` steps following a cosine curve.
If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
:param warmup: see LRSchedule
:param t_total: see LRSchedule
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
:param kw:
"""
warn_t_total = True
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
"""
:param warmup: see LRSchedule
:param t_total: see LRSchedule
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
:param kw:
"""
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
self.cycles = cycles
def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
else:
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
"""
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
learning rate (with hard restarts).
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
assert(cycles >= 1.)
def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
else:
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
return ret
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
return ret
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
"""
All training progress is divided in `cycles` (default=1.) parts of equal length.
Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
followed by a learning rate decreasing from 1. to 0. following a cosine curve.
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
assert(warmup * cycles < 1.)
warmup = warmup * cycles if warmup >= 0 else warmup
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
def get_lr_(self, progress):
progress = progress * self.cycles % 1.
if progress < self.warmup:
return progress / self.warmup
else:
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * progress))
return ret
class WarmupConstantSchedule(_LRSchedule):
class WarmupConstantSchedule(LambdaLR):
"""
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
Keeps learning rate equal to 1. after warmup.
"""
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
return 1.
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
return 1.
super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
class WarmupLinearSchedule(_LRSchedule):
class WarmupLinearSchedule(LambdaLR):
"""
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
"""
warn_t_total = True
def get_lr_(self, progress):
if progress < self.warmup:
return progress / self.warmup
return max((progress - 1.) / (self.warmup - 1.), 0.)
def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return (t_total - step) / max(1, t_total - warmup_steps)
super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
SCHEDULES = {
None: ConstantLR,
"none": ConstantLR,
"warmup_cosine": WarmupCosineSchedule,
"warmup_constant": WarmupConstantSchedule,
"warmup_linear": WarmupLinearSchedule
}
class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix.
Parameters:
lr: learning rate
@ -197,43 +115,20 @@ class BertAdam(Optimizer):
e: Adams epsilon. Default: 1e-6
weight_decay: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in Bert repository)
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
if lr is not required and lr < 0.0:
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1] ))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
# initialize schedule object
if not isinstance(schedule, _LRSchedule):
schedule_type = SCHEDULES[schedule]
schedule = schedule_type(warmup=warmup, t_total=t_total)
else:
if warmup != -1 or t_total != -1:
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
"Please specify custom warmup and t_total in _LRSchedule object.")
defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
if len(state) == 0:
return [0]
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
lr.append(lr_scheduled)
return lr
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
correct_bias=correct_bias)
super(AdamW, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
@ -260,22 +155,28 @@ class BertAdam(Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['next_m'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['next_v'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['b1'], group['b2']
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
state['step'] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
next_m.mul_(beta1).add_(1 - beta1, grad)
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
update = next_m / (next_v.sqrt() + group['e'])
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
step_size = group['lr']
if group['correct_bias']: # No bias correction for Bert
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
@ -284,20 +185,8 @@ class BertAdam(Optimizer):
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
update_with_lr = lr_scheduled * update
p.data.add_(-update_with_lr)
state['step'] += 1
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
# No bias correction
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Add weight decay at the end (fixed version)
if group['weight_decay'] > 0:
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
return loss

View File

@ -1,127 +0,0 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for OpenAI GPT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule
logger = logging.getLogger(__name__)
class OpenAIAdam(Optimizer):
"""Implements Open AI version of Adam algorithm with weight decay fix.
"""
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
vector_l2=False, max_grad_norm=-1, **kwargs):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
# initialize schedule object
if not isinstance(schedule, _LRSchedule):
schedule_type = SCHEDULES[schedule]
schedule = schedule_type(warmup=warmup, t_total=t_total)
else:
if warmup != -1 or t_total != -1:
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
"Please specify custom warmup and t_total in _LRSchedule object.")
defaults = dict(lr=lr, schedule=schedule,
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
max_grad_norm=max_grad_norm)
super(OpenAIAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
if len(state) == 0:
return [0]
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
lr.append(lr_scheduled)
return lr
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['b1'], group['b2']
state['step'] += 1
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['e'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
lr_scheduled = group['lr']
lr_scheduled *= group['schedule'].get_lr(state['step'])
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
# Add weight decay at the end (fixed version)
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
return loss

View File

@ -20,10 +20,9 @@ import unittest
import torch
from pytorch_transformers import BertAdam
from pytorch_transformers import OpenAIAdam
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
import numpy as np
@ -34,12 +33,12 @@ class OptimizationTest(unittest.TestCase):
for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol)
def test_adam(self):
def test_adam_w(self):
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = BertAdam(params=[w], lr=2e-1,
optimizer = AdamW(params=[w], lr=2e-1,
weight_decay=0.0,
max_grad_norm=-1)
for _ in range(100):
@ -52,23 +51,13 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
def test_bert_sched_init(self):
def test_sched_init(self):
m = torch.nn.Linear(50, 50)
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
# shouldn't fail
def test_openai_sched_init(self):
m = torch.nn.Linear(50, 50)
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
# shouldn't fail

View File

@ -98,14 +98,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.build_vocab()
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
if verbose: logger.info('counting file {} ...'.format(path))
assert os.path.exists(path)
sents = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
logger.info(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols)
sents.append(symbols)
@ -116,10 +116,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if verbose: print('counting {} sents ...'.format(len(sents)))
if verbose: logger.info('counting {} sents ...'.format(len(sents)))
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
logger.info(' line {}'.format(idx))
self.counter.update(symbols)
def _build_from_file(self, vocab_file):
@ -147,11 +147,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
logger.info('building vocab from {}'.format(self.vocab_file))
self._build_from_file(self.vocab_file)
print('final vocab size {}'.format(len(self)))
logger.info('final vocab size {}'.format(len(self)))
else:
print('building vocab with min_freq={}, max_size={}'.format(
logger.info('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()
@ -163,18 +163,18 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if cnt < self.min_freq: break
self.add_symbol(sym)
print('final vocab size {} from {} unique tokens'.format(
logger.info('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
add_double_eos=False):
if verbose: print('encoding file {} ...'.format(path))
if verbose: logger.info('encoding file {} ...'.format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
logger.info(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos,
add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols))
@ -185,11 +185,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return encoded
def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: print('encoding {} sents ...'.format(len(sents)))
if verbose: logger.info('encoding {} sents ...'.format(len(sents)))
encoded = []
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
logger.info(' line {}'.format(idx))
encoded.append(self.convert_to_tensor(symbols))
if ordered:
@ -218,7 +218,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
# print('encounter unk {}'.format(sym))
# logger.info('encounter unk {}'.format(sym))
# assert '<eos>' not in sym
if hasattr(self, 'unk_idx'):
return self.sym2idx.get(sym, self.unk_idx)
@ -544,14 +544,14 @@ def get_lm_corpus(datadir, dataset):
fn = os.path.join(datadir, 'cache.pt')
fn_pickle = os.path.join(datadir, 'cache.pkl')
if os.path.exists(fn):
print('Loading cached dataset...')
logger.info('Loading cached dataset...')
corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
print('Loading cached dataset from pickle...')
logger.info('Loading cached dataset from pickle...')
with open(fn, "rb") as fp:
corpus = pickle.load(fp)
else:
print('Producing dataset {}...'.format(dataset))
logger.info('Producing dataset {}...'.format(dataset))
kwargs = {}
if dataset in ['wt103', 'wt2']:
kwargs['special'] = ['<eos>']