various updates to conversion, models and examples

thomwolf 2019-06-26 00:57:53 +02:00
parent 603c513b35
commit e55d4c4ede
6 changed files with 44 additions and 21 deletions

View File

@@ -1394,7 +1394,7 @@ The data for SWAG can be downloaded by cloning the following [repository](https:
 ```shell
 export SWAG_DIR=/path/to/SWAG
-python run_swag.py \
+python run_bert_swag.py \
   --bert_model bert-base-uncased \
   --do_train \
   --do_lower_case \
@@ -1581,7 +1581,6 @@ python run_xlnet_classifier.py \
   --task_name STS-B \
   --do_train \
   --do_eval \
-  --do_lower_case \
   --data_dir $GLUE_DIR/STS-B/ \
   --max_seq_length 128 \
   --train_batch_size 8 \

View File

@@ -70,6 +70,8 @@ def main():
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
                         help="Proportion of training to perform linear learning rate warmup for. "
                              "E.g., 0.1 = 10%% of training.")
+    parser.add_argument("--clip_gradients", default=1.0, type=float,
+                        help="Clip gradient norms.")
     parser.add_argument("--train_batch_size", default=32, type=int,
                         help="Total batch size for training.")
     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
@@ -80,6 +82,8 @@ def main():
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument("--log_every", default=10, type=int,
+                        help="Log metrics every X training steps.")
     # evaluation
     parser.add_argument("--do_eval", action='store_true',
                         help="Whether to run eval on the dev set.")
@@ -234,12 +238,13 @@ def main():
     # Prepare optimizer
-    param_optimizer = list(model.named_parameters())
-    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
-    optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
-        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-    ]
+    optimizer_grouped_parameters = model.parameters()
+    # param_optimizer = list(model.named_parameters())
+    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+    # optimizer_grouped_parameters = [
+    #     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
+    #     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+    # ]
     if args.fp16:
         try:
             from apex.optimizers import FP16_Optimizer
@@ -297,6 +302,9 @@ def main():
                 else:
                     loss.backward()
+                if args.clip_gradients > 0.0:
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
+
                 tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
@@ -310,7 +318,7 @@ def main():
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1
-                    if args.local_rank in [-1, 0]:
+                    if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
                         tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                         tb_writer.add_scalar('loss', loss.item(), global_step)
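For readers skimming the diff, the two new flags work together inside the training loop. Below is a minimal, hypothetical sketch of that pattern; only torch.nn.utils.clip_grad_norm_ is the real API being used here, while model, optimizer, dataloader, tb_writer and args are placeholders, not the script's actual objects.

```python
# Hypothetical sketch of the training-loop changes above; all names except
# torch.nn.utils.clip_grad_norm_ are placeholders for illustration.
import torch

def train_one_epoch(model, optimizer, dataloader, args, tb_writer=None, global_step=0):
    model.train()
    for step, batch in enumerate(dataloader):
        loss = model(**batch)  # assume the model returns a scalar loss
        loss.backward()
        # --clip_gradients: clip the global gradient norm before the update
        # (a value <= 0 would disable clipping).
        if args.clip_gradients > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
        # --log_every: only write TensorBoard scalars every N steps
        # (or every step when the flag is <= 0), mirroring the new condition.
        if tb_writer is not None and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
            tb_writer.add_scalar('loss', loss.item(), global_step)
    return global_step
```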

View File

@@ -20,6 +20,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             load_tf_weights_in_gpt2)
 from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
+                             XLNetForSequenceClassification, XLNetForQuestionAnswering,
                              load_tf_weights_in_xlnet)
 from .optimization import BertAdam
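With this export in place, the two task-specific XLNet heads become importable from the package root. A small illustrative snippet (the config file path is a placeholder; the is_regression argument is the one used by the conversion script later in this commit):

```python
# Illustrative only: the config path is hypothetical.
from pytorch_pretrained_bert import (XLNetConfig,
                                     XLNetForSequenceClassification,
                                     XLNetForQuestionAnswering)

config = XLNetConfig.from_json_file("xlnet_config.json")  # hypothetical path
classifier = XLNetForSequenceClassification(config, is_regression=False)
qa_model = XLNetForQuestionAnswering(config)
```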

View File

@@ -28,20 +28,31 @@ from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
                                                     XLNetForSequenceClassification,
                                                     load_tf_weights_in_xlnet)
 
-GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "sst-2", "sts-b", "qqp", "qnli", "rte", "wnli"]
+GLUE_TASKS = {
+    "cola": "classification",
+    "mnli": "classification",
+    "mrpc": "classification",
+    "sst-2": "classification",
+    "sts-b": "regression",
+    "qqp": "classification",
+    "qnli": "classification",
+    "rte": "classification",
+    "wnli": "classification",
+}
 
 def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
     # Initialise PyTorch model
     config = XLNetConfig.from_json_file(bert_config_file)
 
-    if finetuning_task is not None and finetuning_task.lower() in GLUE_TASKS:
-        model_class = XLNetLMHeadModel
-    elif finetuning_task is not None and 'squad' in finetuning_task.lower():
-        model_class = XLNetForQuestionAnswering
+    finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
+    if finetuning_task in GLUE_TASKS:
+        print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
+        model = XLNetForSequenceClassification(config, is_regression=bool(GLUE_TASKS[finetuning_task] == "regression"))
+    elif 'squad' in finetuning_task:
+        model = XLNetForQuestionAnswering(config)
     else:
-        model_class = XLNetLMHeadModel
-    print("Building PyTorch model {} from configuration: {}".format(str(model_class), str(config)))
-    model = model_class(config)
+        model = XLNetLMHeadModel(config)
 
     # Load weights from tf checkpoint
     load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
@@ -80,6 +91,8 @@ if __name__ == "__main__":
                         type = str,
                         help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
     args = parser.parse_args()
+    print(args)
+
     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.xlnet_config_file,
                                         args.pytorch_dump_folder_path,
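For reference, a hypothetical programmatic call to the converter for a GLUE regression task. All paths are placeholders, and the module name is an assumption since the file name is not visible in this view; only the function signature and the GLUE_TASKS mapping come from the diff above.

```python
# Hypothetical usage sketch; paths and module name are assumptions.
from convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch

convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/xlnet_model.ckpt",    # TF checkpoint prefix
    bert_config_file="/path/to/xlnet_config.json",     # XLNet config in JSON
    pytorch_dump_folder_path="/path/to/pytorch_dump",  # output folder for weights + config
    finetuning_task="sts-b",                           # maps to "regression" in GLUE_TASKS above
)
```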

View File

@@ -30,7 +30,7 @@ from io import open
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@@ -58,11 +58,11 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
     if hasattr(model, 'lm_loss'):
         # We will load also the output bias
         tf_to_pt_map['model/lm_loss/bias'] = model.lm_loss.bias
-    elif hasattr(model, 'sequence_summary') and 'model/sequnece_summary/summary/kernel' in tf_weights:
+    if hasattr(model, 'sequence_summary') and 'model/sequnece_summary/summary/kernel' in tf_weights:
         # We will load also the sequence summary
         tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
         tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
-    elif hasattr(model, 'logits_proj') and finetuning_task is not None and any('model/regression' in name for name in tf_weights.keys()):
+    if hasattr(model, 'logits_proj') and finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights:
         tf_to_pt_map['model/regression_{}/logit/kernel'.format(finetuning_task)] = model.logits_proj.weight
         tf_to_pt_map['model/regression_{}/logit/bias'.format(finetuning_task)] = model.logits_proj.bias
@@ -133,6 +133,8 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
         array = tf.train.load_variable(tf_path, name)
         tf_weights[name] = array
 
+    input("Press Enter to continue...")
+
     # Build TF to PyTorch weights loading map
     tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights, finetuning_task)
@@ -144,7 +146,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
         array = tf_weights[name]
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if 'kernel' in name and 'ff' in name:
+        if 'kernel' in name and ('ff' in name or 'summary' in name or 'logit' in name):
             print("Transposing")
             array = np.transpose(array)
         if isinstance(pointer, list):
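Background for the widened transpose condition: TensorFlow dense layers store their kernels as (in_features, out_features), while torch.nn.Linear.weight is laid out as (out_features, in_features), so the newly mapped summary and logit kernels need the same transpose already applied to the feed-forward kernels. A tiny illustrative sketch (the shapes are made up):

```python
# Toy illustration of the TF-to-PyTorch kernel transpose; shapes are arbitrary.
import numpy as np
import torch

in_features, out_features = 1024, 1                                   # e.g. a regression logit head
tf_kernel = np.zeros((in_features, out_features), dtype=np.float32)   # TF layout: (in, out)

linear = torch.nn.Linear(in_features, out_features)
linear.weight.data = torch.from_numpy(np.transpose(tf_kernel))        # PyTorch layout: (out, in)
assert tuple(linear.weight.shape) == (out_features, in_features)
```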