diff --git a/extract_features_pytorch.py b/extract_features_pytorch.py
index 80fff700f8d..5dfed69014e 100644
--- a/extract_features_pytorch.py
+++ b/extract_features_pytorch.py
@@ -316,7 +316,7 @@ def read_examples(input_file):
   return examples
 
 
-def main(_):
+def main():
   tf.logging.set_verbosity(tf.logging.INFO)
 
   layer_indexes = [int(x) for x in args.layers.split(",")]
@@ -387,4 +387,4 @@ def main(_):
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  main()
diff --git a/modeling_pytorch.py b/modeling_pytorch.py
index 0dc95269c9e..ca6e49f24d2 100644
--- a/modeling_pytorch.py
+++ b/modeling_pytorch.py
@@ -441,8 +441,8 @@ class BertForSequenceClassification(nn.Module):
 class BertForQuestionAnswering(nn.Module):
     """BERT model for Question Answering (span extraction).
 
-    This module is composed of the BERT model with linear layers on top of
-    the sequence output.
+    This module is composed of the BERT model with a linear layer on top of
+    the sequence output that computes start_logits and end_logits.
 
     Example usage:
     ```python
@@ -455,7 +455,7 @@ class BertForQuestionAnswering(nn.Module):
         num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
 
     model = BertForQuestionAnswering(config)
-    logits = model(input_ids, token_type_ids, input_mask)
+    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
     def __init__(self, config):
diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py
index 928103c2b29..04be6fd03d8 100644
--- a/run_squad_pytorch.py
+++ b/run_squad_pytorch.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import logging
 import json
 import math
 import os
@@ -29,6 +30,15 @@ import six
 import tensorflow as tf
 import argparse
 
+import torch  # used below for torch.device / torch.cuda / torch.load
+from modeling_pytorch import BertConfig, BertForQuestionAnswering
+from optimization_pytorch import BERTAdam
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 parser = argparse.ArgumentParser()
 
 ## Required parameters
@@ -94,6 +104,10 @@
 parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `use_tpu` is True. Total number of TPU cores to use.")
 parser.add_argument("--verbose_logging", default=False, type=bool, help="If true, all of the warnings related to data processing will be printed. "
                                                                         "A number of warnings are expected for a normal SQuAD evaluation.")
+parser.add_argument("--local_rank",
+                    type=int,
+                    default=-1,
+                    help="local_rank for distributed training on gpus")
 
 args = parser.parse_args()
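Note on the modeling_pytorch.py docstring change above: the head it describes is a single nn.Linear producing two scores per token, which the forward pass splits into start_logits and end_logits. A minimal self-contained sketch of that wiring (hypothetical names, not the actual code in modeling_pytorch.py):

```python
import torch
import torch.nn as nn

class SpanHead(nn.Module):
    """A single linear layer on top of the sequence output, as the docstring describes."""
    def __init__(self, hidden_size):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)  # two scores per token: start, end

    def forward(self, sequence_output):
        # sequence_output: [batch, seq_len, hidden_size]
        logits = self.qa_outputs(sequence_output)           # [batch, seq_len, 2]
        start_logits, end_logits = logits.split(1, dim=-1)  # two [batch, seq_len, 1] tensors
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

start, end = SpanHead(768)(torch.randn(2, 384, 768))  # BERT-base hidden size, SQuAD-style seq length
```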
" "A number of warnings are expected for a normal SQuAD evaluation.") +parser.add_argument("--local_rank", + type=int, + default=-1, + help = "local_rank for distributed training on gpus") args = parser.parse_args() @@ -926,8 +939,15 @@ def _compute_softmax(scores): return probs -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) +def main(): + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + n_gpu = torch.cuda.device_count() + else: + device = torch.device("cuda", args.local_rank) + n_gpu = 1 + # print("Initializing the distributed backend: NCCL") + print("device", device, "n_gpu", n_gpu) if not args.do_train and not args.do_predict: raise ValueError("At least one of `do_train` or `do_predict` must be True.") @@ -941,7 +961,7 @@ def main(_): raise ValueError( "If `do_predict` is True, then `predict_file` must be specified.") - bert_config = modeling.BertConfig.from_json_file(args.bert_config_file) + bert_config = BertConfig.from_json_file(args.bert_config_file) if args.max_seq_length > bert_config.max_position_embeddings: raise ValueError( @@ -949,54 +969,69 @@ def main(_): "was only trained up to sequence length %d" % (args.max_seq_length, bert_config.max_position_embeddings)) - tf.gfile.MakeDirs(args.output_dir) + if os.path.exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError(f"Output directory ({args.output_dir}) already exists and is " + f"not empty.") + os.makedirs(args.output_dir, exist_ok=True) tokenizer = tokenization.FullTokenizer( vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) - tpu_cluster_resolver = None - if args.use_tpu and args.tpu_name: - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - args.tpu_name, zone=args.tpu_zone, project=args.gcp_project) + # tpu_cluster_resolver = None + # if args.use_tpu and args.tpu_name: + # tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + # args.tpu_name, zone=args.tpu_zone, project=args.gcp_project) - is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.contrib.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=args.master, - model_dir=args.output_dir, - save_checkpoints_steps=args.save_checkpoints_steps, - tpu_config=tf.contrib.tpu.TPUConfig( - iterations_per_loop=args.iterations_per_loop, - num_shards=args.num_tpu_cores, - per_host_input_for_training=is_per_host)) + # is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 + # run_config = tf.contrib.tpu.RunConfig( + # cluster=tpu_cluster_resolver, + # master=args.master, + # model_dir=args.output_dir, + # save_checkpoints_steps=args.save_checkpoints_steps, + # tpu_config=tf.contrib.tpu.TPUConfig( + # iterations_per_loop=args.iterations_per_loop, + # num_shards=args.num_tpu_cores, + # per_host_input_for_training=is_per_host)) train_examples = None num_train_steps = None - num_warmup_steps = None + # num_warmup_steps = None if args.do_train: train_examples = read_squad_examples( input_file=args.train_file, is_training=True) num_train_steps = int( len(train_examples) / args.train_batch_size * args.num_train_epochs) - num_warmup_steps = int(num_train_steps * args.warmup_proportion) + # num_warmup_steps = int(num_train_steps * args.warmup_proportion) - model_fn = model_fn_builder( - bert_config=bert_config, - init_checkpoint=args.init_checkpoint, - learning_rate=args.learning_rate, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps, - use_tpu=args.use_tpu, - 
@@ -1067,4 +1103,4 @@ def main(_):
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  main()
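For context on --local_rank: the patch only uses it for device placement so far, and the NCCL setup remains a commented-out print inside main(). The usual pattern this is building toward looks roughly like the sketch below; it assumes one process per GPU with rank and world size passed through environment variables by a launcher such as torch.distributed.launch, and is not part of this patch:

```python
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

if args.local_rank != -1:
    # Pin this process to its GPU, then join the NCCL process group.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device("cuda", args.local_rank)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```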