diff --git a/examples/run_squad.py b/examples/run_squad.py
index 0c0fbf29636..8a9f123d20a 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions(examples, features, all_results, args.n_best_size,
                         args.max_answer_length, args.do_lower_case, output_prediction_file,
                         output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                        args.version_2_with_negative, args.null_score_diff_threshold)
+                        args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)
 
     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
                                                 max_query_length=args.max_query_length,
-                                                is_training=not evaluate)
+                                                is_training=not evaluate, add_prefix_space=(args.model_type == 'roberta'))
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
diff --git a/examples/utils_squad.py b/examples/utils_squad.py
index b990ecc8420..82a4b96b79e 100644
--- a/examples/utils_squad.py
+++ b/examples/utils_squad.py
@@ -25,6 +25,7 @@ import collections
 from io import open
 
 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_roberta import RobertaTokenizer
 
 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -192,7 +193,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True, add_prefix_space=False):
     """Loads a data file into a list of `InputBatch`s."""
 
     unique_id = 1000000000
@@ -205,8 +206,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
 
-
-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)
 
         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
+            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text)
+                example.orig_answer_text, add_prefix_space)
 
             # The -3 accounts for [CLS], [SEP] and [SEP]
             max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 
 
 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+                         orig_answer_text, add_prefix_space):
    """Returns tokenized answer spans that better match the annotated answer."""
 
     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))
 
     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold):
+                      version_2_with_negative, null_score_diff_threshold, tokenizer, model_type='bert'):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,15 +576,22 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                 tok_text = " ".join(tok_tokens)
 
                 # De-tokenize WordPieces that have been split off.
-                tok_text = tok_text.replace(" ##", "")
-                tok_text = tok_text.replace("##", "")
+                if model_type == 'roberta':
+                    # Byte-level BPE has no '##' markers; rebuild the text with the tokenizer.
+                    tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+                    tok_text = " ".join(tok_text.strip().split())
+                    orig_text = " ".join(orig_tokens)
+                    final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                else:
+                    tok_text = tok_text.replace(" ##", "")
+                    tok_text = tok_text.replace("##", "")
 
-                # Clean whitespace
-                tok_text = tok_text.strip()
-                tok_text = " ".join(tok_text.split())
-                orig_text = " ".join(orig_tokens)
+                    # Clean whitespace
+                    tok_text = tok_text.strip()
+                    tok_text = " ".join(tok_text.split())
+                    orig_text = " ".join(orig_tokens)
 
-                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                    final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
 
                 if final_text in seen_predictions:
                     continue
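
Note on add_prefix_space: RoBERTa's byte-level BPE encodes a leading space into the token itself, so a word tokenized in isolation (as convert_examples_to_features does for every entry in example.doc_tokens) is split differently from the same word inside a sentence. Threading add_prefix_space through the feature-conversion path restores the mid-sentence form. A minimal sketch of the effect, assuming a transformers release contemporary with this patch in which RobertaTokenizer.tokenize accepts the add_prefix_space keyword:

    from transformers.tokenization_roberta import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Tokenized in isolation the word is treated as sentence-initial
    # (no leading-space marker on the first piece).
    print(tokenizer.tokenize('Japanese'))

    # With add_prefix_space=True it is tokenized as it would appear
    # mid-sentence, which is what per-word document tokenization needs.
    print(tokenizer.tokenize('Japanese', add_prefix_space=True))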
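
The write_predictions change mirrors this on the output side: stripping WordPiece '##' markers does not apply to byte-level BPE, so the roberta branch rebuilds the answer text with convert_tokens_to_string. A sketch of that path under the same assumptions; span_tokens is a hypothetical stand-in for the token slice a feature would provide:

    # Hypothetical answer span; in write_predictions this comes from
    # feature.tokens[pred.start_index:(pred.end_index + 1)].
    span_tokens = tokenizer.tokenize('the answer span', add_prefix_space=True)

    tok_text = tokenizer.convert_tokens_to_string(span_tokens)
    tok_text = " ".join(tok_text.strip().split())  # normalize whitespace
    print(tok_text)  # -> 'the answer span'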