# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Run BERT on SQuAD.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import six import argparse import collections import logging import json import math import os from tqdm import tqdm, trange import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import tokenization_pytorch from modeling_pytorch import BertConfig, BertForQuestionAnswering from optimization_pytorch import BERTAdam logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) logger = logging.getLogger(__name__) class SquadExample(object): """A single training/test example for simple sequence classification.""" def __init__(self, qas_id, question_text, doc_tokens, orig_answer_text=None, start_position=None, end_position=None): self.qas_id = qas_id self.question_text = question_text self.doc_tokens = doc_tokens self.orig_answer_text = orig_answer_text self.start_position = start_position self.end_position = end_position def __str__(self): return self.__repr__() def __repr__(self): s = "" s += "qas_id: %s" % (tokenization_pytorch.printable_text(self.qas_id)) s += ", question_text: %s" % ( tokenization_pytorch.printable_text(self.question_text)) s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) if self.start_position: s += ", start_position: %d" % (self.start_position) if self.start_position: s += ", end_position: %d" % (self.end_position) return s class InputFeatures(object): """A single set of features of data.""" def __init__(self, unique_id, example_index, doc_span_index, tokens, token_to_orig_map, token_is_max_context, input_ids, input_mask, segment_ids, start_position=None, end_position=None): self.unique_id = unique_id self.example_index = example_index self.doc_span_index = doc_span_index self.tokens = tokens self.token_to_orig_map = token_to_orig_map self.token_is_max_context = token_is_max_context self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids self.start_position = start_position self.end_position = end_position def read_squad_examples(input_file, is_training): """Read a SQuAD json file into a list of SquadExample.""" with open(input_file, "r") as reader: input_data = json.load(reader)["data"] def is_whitespace(c): if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: return True return False examples = [] for entry in input_data: for paragraph in entry["paragraphs"]: paragraph_text = paragraph["context"] doc_tokens = [] char_to_word_offset = [] prev_is_whitespace = True for c in paragraph_text: if is_whitespace(c): prev_is_whitespace = True else: if prev_is_whitespace: doc_tokens.append(c) else: doc_tokens[-1] += c prev_is_whitespace = False char_to_word_offset.append(len(doc_tokens) - 1) for qa in paragraph["qas"]: qas_id = qa["id"] question_text = qa["question"] start_position = None end_position = None orig_answer_text = None if is_training: if len(qa["answers"]) != 1: raise ValueError( "For training, each question should have exactly 1 answer.") answer = qa["answers"][0] orig_answer_text = answer["text"] answer_offset = answer["answer_start"] answer_length = len(orig_answer_text) start_position = char_to_word_offset[answer_offset] end_position = char_to_word_offset[answer_offset + answer_length - 1] # Only add answers where the text can be exactly recovered from the # document. If this CAN'T happen it's likely due to weird Unicode # stuff so we will just skip the example. # # Note that this means for training mode, every example is NOT # guaranteed to be preserved. actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) cleaned_answer_text = " ".join( tokenization_pytorch.whitespace_tokenize(orig_answer_text)) if actual_text.find(cleaned_answer_text) == -1: logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) continue example = SquadExample( qas_id=qas_id, question_text=question_text, doc_tokens=doc_tokens, orig_answer_text=orig_answer_text, start_position=start_position, end_position=end_position) examples.append(example) return examples def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training): """Loads a data file into a list of `InputBatch`s.""" unique_id = 1000000000 features = [] for (example_index, example) in enumerate(examples): query_tokens = tokenizer.tokenize(example.question_text) if len(query_tokens) > max_query_length: query_tokens = query_tokens[0:max_query_length] tok_to_orig_index = [] orig_to_tok_index = [] all_doc_tokens = [] for (i, token) in enumerate(example.doc_tokens): orig_to_tok_index.append(len(all_doc_tokens)) sub_tokens = tokenizer.tokenize(token) for sub_token in sub_tokens: tok_to_orig_index.append(i) all_doc_tokens.append(sub_token) tok_start_position = None tok_end_position = None if is_training: tok_start_position = orig_to_tok_index[example.start_position] if example.end_position < len(example.doc_tokens) - 1: tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 else: tok_end_position = len(all_doc_tokens) - 1 (tok_start_position, tok_end_position) = _improve_answer_span( all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.orig_answer_text) # The -3 accounts for [CLS], [SEP] and [SEP] max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 # We can have documents that are longer than the maximum sequence length. # To deal with this we do a sliding window approach, where we take chunks # of the up to our max length with a stride of `doc_stride`. _DocSpan = collections.namedtuple( # pylint: disable=invalid-name "DocSpan", ["start", "length"]) doc_spans = [] start_offset = 0 while start_offset < len(all_doc_tokens): length = len(all_doc_tokens) - start_offset if length > max_tokens_for_doc: length = max_tokens_for_doc doc_spans.append(_DocSpan(start=start_offset, length=length)) if start_offset + length == len(all_doc_tokens): break start_offset += min(length, doc_stride) for (doc_span_index, doc_span) in enumerate(doc_spans): tokens = [] token_to_orig_map = {} token_is_max_context = {} segment_ids = [] tokens.append("[CLS]") segment_ids.append(0) for token in query_tokens: tokens.append(token) segment_ids.append(0) tokens.append("[SEP]") segment_ids.append(0) for i in range(doc_span.length): split_token_index = doc_span.start + i token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index) token_is_max_context[len(tokens)] = is_max_context tokens.append(all_doc_tokens[split_token_index]) segment_ids.append(1) tokens.append("[SEP]") segment_ids.append(1) input_ids = tokenizer.convert_tokens_to_ids(tokens) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < max_seq_length: input_ids.append(0) input_mask.append(0) segment_ids.append(0) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length start_position = None end_position = None if is_training: # For training, if our document chunk does not contain an annotation # we throw it out, since there is nothing to predict. doc_start = doc_span.start doc_end = doc_span.start + doc_span.length - 1 if (example.start_position < doc_start or example.end_position < doc_start or example.start_position > doc_end or example.end_position > doc_end): continue doc_offset = len(query_tokens) + 2 start_position = tok_start_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset if example_index < 20: logger.info("*** Example ***") logger.info("unique_id: %s" % (unique_id)) logger.info("example_index: %s" % (example_index)) logger.info("doc_span_index: %s" % (doc_span_index)) logger.info("tokens: %s" % " ".join( [tokenization_pytorch.printable_text(x) for x in tokens])) logger.info("token_to_orig_map: %s" % " ".join( ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) logger.info("token_is_max_context: %s" % " ".join([ "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) ])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info( "input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info( "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) if is_training: answer_text = " ".join(tokens[start_position:(end_position + 1)]) logger.info("start_position: %d" % (start_position)) logger.info("end_position: %d" % (end_position)) logger.info( "answer: %s" % (tokenization_pytorch.printable_text(answer_text))) features.append( InputFeatures( unique_id=unique_id, example_index=example_index, doc_span_index=doc_span_index, tokens=tokens, token_to_orig_map=token_to_orig_map, token_is_max_context=token_is_max_context, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, start_position=start_position, end_position=end_position)) unique_id += 1 return features def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): """Returns tokenized answer spans that better match the annotated answer.""" # The SQuAD annotations are character based. We first project them to # whitespace-tokenized words. But then after WordPiece tokenization, we can # often find a "better match". For example: # # Question: What year was John Smith born? # Context: The leader was John Smith (1895-1943). # Answer: 1895 # # The original whitespace-tokenized answer will be "(1895-1943).". However # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match # the exact answer, 1895. # # However, this is not always possible. Consider the following: # # Question: What country is the top exporter of electornics? # Context: The Japanese electronics industry is the lagest in the world. # Answer: Japan # # In this case, the annotator chose "Japan" as a character sub-span of # the word "Japanese". Since our WordPiece tokenizer does not split # "Japanese", we just use "Japanese" as the annotation. This is fairly rare # in SQuAD, but does happen. tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) for new_start in range(input_start, input_end + 1): for new_end in range(input_end, new_start - 1, -1): text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) if text_span == tok_answer_text: return (new_start, new_end) return (input_start, input_end) def _check_is_max_context(doc_spans, cur_span_index, position): """Check if this is the 'max context' doc span for the token.""" # Because of the sliding window approach taken to scoring documents, a single # token can appear in multiple documents. E.g. # Doc: the man went to the store and bought a gallon of milk # Span A: the man went to the # Span B: to the store and bought # Span C: and bought a gallon of # ... # # Now the word 'bought' will have two scores from spans B and C. We only # want to consider the score with "maximum context", which we define as # the *minimum* of its left and right context (the *sum* of left and # right context will always be the same, of course). # # In the example the maximum context for 'bought' would be span C since # it has 1 left context and 3 right context, while span B has 4 left context # and 0 right context. best_score = None best_span_index = None for (span_index, doc_span) in enumerate(doc_spans): end = doc_span.start + doc_span.length - 1 if position < doc_span.start: continue if position > end: continue num_left_context = position - doc_span.start num_right_context = end - position score = min(num_left_context, num_right_context) + 0.01 * doc_span.length if best_score is None or score > best_score: best_score = score best_span_index = span_index return cur_span_index == best_span_index RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) def write_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case, output_prediction_file, output_nbest_file): """Write final predictions to the json file.""" logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing nbest to: %s" % (output_nbest_file)) example_index_to_features = collections.defaultdict(list) for feature in all_features: example_index_to_features[feature.example_index].append(feature) unique_id_to_result = {} for result in all_results: unique_id_to_result[result.unique_id] = result _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) all_predictions = collections.OrderedDict() all_nbest_json = collections.OrderedDict() for (example_index, example) in enumerate(all_examples): features = example_index_to_features[example_index] prelim_predictions = [] for (feature_index, feature) in enumerate(features): result = unique_id_to_result[feature.unique_id] start_indexes = _get_best_indexes(result.start_logits, n_best_size) end_indexes = _get_best_indexes(result.end_logits, n_best_size) for start_index in start_indexes: for end_index in end_indexes: # We could hypothetically create invalid predictions, e.g., predict # that the start of the span is in the question. We throw out all # invalid predictions. if start_index >= len(feature.tokens): continue if end_index >= len(feature.tokens): continue if start_index not in feature.token_to_orig_map: continue if end_index not in feature.token_to_orig_map: continue if not feature.token_is_max_context.get(start_index, False): continue if end_index < start_index: continue length = end_index - start_index + 1 if length > max_answer_length: continue prelim_predictions.append( _PrelimPrediction( feature_index=feature_index, start_index=start_index, end_index=end_index, start_logit=result.start_logits[start_index], end_logit=result.end_logits[end_index])) prelim_predictions = sorted( prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name "NbestPrediction", ["text", "start_logit", "end_logit"]) seen_predictions = {} nbest = [] for pred in prelim_predictions: if len(nbest) >= n_best_size: break feature = features[pred.feature_index] tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] tok_text = " ".join(tok_tokens) # De-tokenize WordPieces that have been split off. tok_text = tok_text.replace(" ##", "") tok_text = tok_text.replace("##", "") # Clean whitespace tok_text = tok_text.strip() tok_text = " ".join(tok_text.split()) orig_text = " ".join(orig_tokens) final_text = get_final_text(tok_text, orig_text, do_lower_case) if final_text in seen_predictions: continue seen_predictions[final_text] = True nbest.append( _NbestPrediction( text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit)) # In very rare edge cases we could have no valid predictions. So we # just create a nonce prediction in this case to avoid failure. if not nbest: nbest.append( _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) assert len(nbest) >= 1 total_scores = [] for entry in nbest: total_scores.append(entry.start_logit + entry.end_logit) probs = _compute_softmax(total_scores) nbest_json = [] for (i, entry) in enumerate(nbest): output = collections.OrderedDict() output["text"] = entry.text output["probability"] = probs[i] output["start_logit"] = entry.start_logit output["end_logit"] = entry.end_logit nbest_json.append(output) assert len(nbest_json) >= 1 all_predictions[example.qas_id] = nbest_json[0]["text"] all_nbest_json[example.qas_id] = nbest_json with open(output_prediction_file, "w") as writer: writer.write(json.dumps(all_predictions, indent=4) + "\n") with open(output_nbest_file, "w") as writer: writer.write(json.dumps(all_nbest_json, indent=4) + "\n") def get_final_text(pred_text, orig_text, do_lower_case): """Project the tokenized prediction back to the original text.""" # When we created the data, we kept track of the alignment between original # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So # now `orig_text` contains the span of our original text corresponding to the # span that we predicted. # # However, `orig_text` may contain extra characters that we don't want in # our prediction. # # For example, let's say: # pred_text = steve smith # orig_text = Steve Smith's # # We don't want to return `orig_text` because it contains the extra "'s". # # We don't want to return `pred_text` because it's already been normalized # (the SQuAD eval script also does punctuation stripping/lower casing but # our tokenizer does additional normalization like stripping accent # characters). # # What we really want to return is "Steve Smith". # # Therefore, we have to apply a semi-complicated alignment heruistic between # `pred_text` and `orig_text` to get a character-to-charcter alignment. This # can fail in certain cases in which case we just return `orig_text`. def _strip_spaces(text): ns_chars = [] ns_to_s_map = collections.OrderedDict() for (i, c) in enumerate(text): if c == " ": continue ns_to_s_map[len(ns_chars)] = i ns_chars.append(c) ns_text = "".join(ns_chars) return (ns_text, ns_to_s_map) # We first tokenize `orig_text`, strip whitespace from the result # and `pred_text`, and check if they are the same length. If they are # NOT the same length, the heuristic has failed. If they are the same # length, we assume the characters are one-to-one aligned. tokenizer = tokenization_pytorch.BasicTokenizer(do_lower_case=do_lower_case) tok_text = " ".join(tokenizer.tokenize(orig_text)) start_position = tok_text.find(pred_text) if start_position == -1: if args.verbose_logging: logger.info( "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) return orig_text end_position = start_position + len(pred_text) - 1 (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) if len(orig_ns_text) != len(tok_ns_text): if args.verbose_logging: logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) return orig_text # We then project the characters in `pred_text` back to `orig_text` using # the character-to-character alignment. tok_s_to_ns_map = {} for (i, tok_index) in six.iteritems(tok_ns_to_s_map): tok_s_to_ns_map[tok_index] = i orig_start_position = None if start_position in tok_s_to_ns_map: ns_start_position = tok_s_to_ns_map[start_position] if ns_start_position in orig_ns_to_s_map: orig_start_position = orig_ns_to_s_map[ns_start_position] if orig_start_position is None: if args.verbose_logging: logger.info("Couldn't map start position") return orig_text orig_end_position = None if end_position in tok_s_to_ns_map: ns_end_position = tok_s_to_ns_map[end_position] if ns_end_position in orig_ns_to_s_map: orig_end_position = orig_ns_to_s_map[ns_end_position] if orig_end_position is None: if args.verbose_logging: logger.info("Couldn't map end position") return orig_text output_text = orig_text[orig_start_position:(orig_end_position + 1)] return output_text def _get_best_indexes(logits, n_best_size): """Get the n-best logits from a list.""" index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) best_indexes = [] for i in range(len(index_and_score)): if i >= n_best_size: break best_indexes.append(index_and_score[i][0]) return best_indexes def _compute_softmax(scores): """Compute softmax probability over raw logits.""" if not scores: return [] max_score = None for score in scores: if max_score is None or score > max_score: max_score = score exp_scores = [] total_sum = 0.0 for score in scores: x = math.exp(score - max_score) exp_scores.append(x) total_sum += x probs = [] for score in exp_scores: probs.append(score / total_sum) return probs def main(): parser = argparse.ArgumentParser() ## Required parameters parser.add_argument("--bert_config_file", default=None, type=str, required=True, help="The config json file corresponding to the pre-trained BERT model. " "This specifies the model architecture.") parser.add_argument("--vocab_file", default=None, type=str, required=True, help="The vocabulary file that the BERT model was trained on.") parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model checkpoints will be written.") ## Other parameters parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") parser.add_argument("--predict_file", default=None, type=str, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") parser.add_argument("--init_checkpoint", default=None, type=str, help="Initial checkpoint (usually from a pre-trained BERT model).") parser.add_argument("--do_lower_case", default=True, action='store_true', help="Whether to lower case the input text. Should be True for uncased " "models and False for cased models.") parser.add_argument("--max_seq_length", default=384, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences " "longer than this will be truncated, and sequences shorter than this will be padded.") parser.add_argument("--doc_stride", default=128, type=int, help="When splitting up a long document into chunks, how much stride to take between chunks.") parser.add_argument("--max_query_length", default=64, type=int, help="The maximum number of tokens for the question. Questions longer than this will " "be truncated to this length.") parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " "of training.") parser.add_argument("--save_checkpoints_steps", default=1000, type=int, help="How often to save the model checkpoint.") parser.add_argument("--iterations_per_loop", default=1000, type=int, help="How many steps to make in each estimator call.") parser.add_argument("--n_best_size", default=20, type=int, help="The total number of n-best predictions to generate in the nbest_predictions.json " "output file.") parser.add_argument("--max_answer_length", default=30, type=int, help="The maximum length of an answer that can be generated. This is needed because the start " "and end predictions are not conditioned on one another.") ### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ### # parser.add_argument("--use_tpu", default=False, action='store_true', help="Whether to use TPU or GPU/CPU.") # parser.add_argument("--tpu_name", default=None, type=str, # help="The Cloud TPU to use for training. This should be either the name used when creating the " # "Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") # parser.add_argument("--tpu_zone", default=None, type=str, # help="[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt " # "to automatically detect the GCE project from metadata.") # parser.add_argument("--gcp_project", default=None, type=str, # help="[Optional] Project name for the Cloud TPU-enabled project. If not specified, we will attempt " # "to automatically detect the GCE project from metadata.") # parser.add_argument("--master", default=None, type=str, help="[Optional] TensorFlow master URL.") # parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `use_tpu` is True. " # "Total number of TPU cores to use.") ### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ### parser.add_argument("--verbose_logging", default=False, action='store_true', help="If true, all of the warnings related to data processing will be printed. " "A number of warnings are expected for a normal SQuAD evaluation.") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") args = parser.parse_args() if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: device = torch.device("cuda", args.local_rank) n_gpu = 1 # print("Initializing the distributed backend: NCCL") print("device", device, "n_gpu", n_gpu) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu>0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_predict: raise ValueError("At least one of `do_train` or `do_predict` must be True.") if args.do_train: if not args.train_file: raise ValueError( "If `do_train` is True, then `train_file` must be specified.") if args.do_predict: if not args.predict_file: raise ValueError( "If `do_predict` is True, then `predict_file` must be specified.") bert_config = BertConfig.from_json_file(args.bert_config_file) if args.max_seq_length > bert_config.max_position_embeddings: raise ValueError( "Cannot use sequence length %d because the BERT model " "was only trained up to sequence length %d" % (args.max_seq_length, bert_config.max_position_embeddings)) if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory () already exists and is not empty.") os.makedirs(args.output_dir, exist_ok=True) tokenizer = tokenization_pytorch.FullTokenizer( vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) train_examples = None num_train_steps = None if args.do_train: train_examples = read_squad_examples( input_file=args.train_file, is_training=True) num_train_steps = int( len(train_examples) / args.train_batch_size * args.num_train_epochs) model = BertForQuestionAnswering(bert_config) if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) if n_gpu > 1: model = torch.nn.DataParallel(model) no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [ {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} ] optimizer = BERTAdam(optimizer_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) global_step = 0 if args.do_train: train_features = convert_examples_to_features( examples=train_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=True) logger.info("***** Running training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) logger.info("HHHHH Loading data") all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) #all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) logger.info("HHHHH Creating dataset") #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) logger.info("HHHHH Dataloader") train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) logger.info("HHHHH Starting Traing") model.train() for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader, desc="Iteration"): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device) #label_ids = label_ids.to(device) start_positions = start_positions.to(device) end_positions = start_positions.to(device) start_positions = start_positions.view(-1, 1) end_positions = end_positions.view(-1, 1) logger.info("HHHHH Forward") loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) model.zero_grad() logger.info("HHHHH Backward") loss.backward() logger.info("HHHHH Loading data") optimizer.step() global_step += 1 logger.info("Done %s steps", global_step) if args.do_predict: eval_examples = read_squad_examples( input_file=args.predict_file, is_training=False) eval_features = convert_examples_to_features( examples=eval_examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=False) logger.info("***** Running predictions *****") logger.info(" Num orig examples = %d", len(eval_examples)) logger.info(" Num split examples = %d", len(eval_features)) logger.info(" Batch size = %d", args.predict_batch_size) all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) #all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) #eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) if args.local_rank == -1: eval_sampler = SequentialSampler(eval_data) else: eval_sampler = DistributedSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) model.eval() all_results = [] logger.info("Start evaulating") #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader: for input_ids, input_mask, segment_ids, example_index in eval_dataloader: if len(all_results) % 1000 == 0: logger.info("Processing example: %d" % (len(all_results))) input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device) start_logits, end_logits = model(input_ids, segment_ids, input_mask) unique_id = [int(eval_features[e.item()].unique_id) for e in example_index] #start_logits = [x.item() for x in start_logits] start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits] #end_logits = [x.item() for x in end_logits] end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] for idx, i in enumerate(unique_id): s = [float(x) for x in start_logits[idx]] e = [float(x) for x in end_logits[idx]] all_results.append( RawResult( unique_id=i, start_logits=s, end_logits=e ) ) # all_results.append( # RawResult( # unique_id=unique_id, # start_logits=start_logits, # end_logits=end_logits)) output_prediction_file = os.path.join(args.output_dir, "predictions.json") output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") write_predictions(eval_examples, eval_features, all_results, args.n_best_size, args.max_answer_length, args.do_lower_case, output_prediction_file, output_nbest_file) if __name__ == "__main__": main()