Fix RoBERTa tokenizer decoding

erenup 2019-10-03 18:33:53 +08:00
parent ebb32261b1
commit 22e7c4edaf
2 changed files with 24 additions and 17 deletions


@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     write_predictions(examples, features, all_results, args.n_best_size,
                       args.max_answer_length, args.do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                      args.version_2_with_negative, args.null_score_diff_threshold)
+                      args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)

     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
                                                 max_query_length=args.max_query_length,
-                                                is_training=not evaluate)
+                                                is_training=not evaluate, add_prefix_space=True if args.model_type == 'roberta' else False)
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
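For context, a minimal sketch (not part of the patch) of why the flag is switched on only for RoBERTa: its byte-level BPE encodes a leading space into the token itself (rendered as Ġ), so a word tokenized in isolation comes out differently from the same word inside running text. This assumes a transformers release from this era (2.x), whose RobertaTokenizer.tokenize forwards add_prefix_space to the BPE tokenizer, and the 'roberta-base' checkpoint.

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Tokenized in isolation, the word has no leading-space marker:
print(tokenizer.tokenize("world"))                         # ['world']

# With the flag, it is encoded as it would appear mid-sentence:
print(tokenizer.tokenize("world", add_prefix_space=True))  # ['Ġworld']

BERT's WordPiece has no such marker, which is why the default stays False for every other model type.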


@@ -25,6 +25,7 @@ import collections
 from io import open

 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_roberta import RobertaTokenizer

 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -192,7 +193,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True, add_prefix_space=False):
     """Loads a data file into a list of `InputBatch`s."""

     unique_id = 1000000000
@@ -205,8 +206,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
-
-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)

         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
+            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
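The same consideration applies to this loop (a sketch, reusing the assumed tokenizer from the earlier example): SQuAD doc_tokens are whitespace-split words tokenized one at a time, so without the flag none of the sub-tokens would carry the Ġ marker that whole-text tokenization produces.

doc_tokens = ["New", "York", "City"]           # whitespace-split, as in a SquadExample
all_doc_tokens = []
for token in doc_tokens:
    all_doc_tokens.extend(tokenizer.tokenize(token, add_prefix_space=True))
print(all_doc_tokens)                           # ['ĠNew', 'ĠYork', 'ĠCity']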
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text)
+                example.orig_answer_text, add_prefix_space)

         # The -3 accounts for [CLS], [SEP] and [SEP]
         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,

 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+                         orig_answer_text, add_prefix_space):
     """Returns tokenized answer spans that better match the annotated answer."""

     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
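_improve_answer_span compares the joined sub-tokens of the annotated answer against joined slices of the document tokens, so both sides must be tokenized with the same flag. A sketch of the failure mode otherwise (same assumed tokenizer as above):

doc_slice = ['ĠNew', 'ĠYork']                  # document tokens, tokenized with the flag
with_flag = " ".join(tokenizer.tokenize("New York", add_prefix_space=True))
without_flag = " ".join(tokenizer.tokenize("New York"))
print(with_flag == " ".join(doc_slice))        # True  -> the improved span is found
print(without_flag == " ".join(doc_slice))     # False -> the match silently fails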
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold):
+                      version_2_with_negative, null_score_diff_threshold, tokenizer, mode_type='bert'):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,15 +576,22 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             tok_text = " ".join(tok_tokens)

             # De-tokenize WordPieces that have been split off.
-            tok_text = tok_text.replace(" ##", "")
-            tok_text = tok_text.replace("##", "")
+            if mode_type == 'roberta':
+                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+                tok_text = tok_text.replace("##", "")
+                tok_text = " ".join(tok_text.strip().split())
+                orig_text = " ".join(orig_tokens)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, None)
+            else:
+                tok_text = tok_text.replace(" ##", "")
+                tok_text = tok_text.replace("##", "")

-            # Clean whitespace
-            tok_text = tok_text.strip()
-            tok_text = " ".join(tok_text.split())
-            orig_text = " ".join(orig_tokens)
+                # Clean whitespace
+                tok_text = tok_text.strip()
+                tok_text = " ".join(tok_text.split())
+                orig_text = " ".join(orig_tokens)

-            final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)

             if final_text in seen_predictions:
                 continue
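A minimal sketch of the decoding bug this hunk fixes (again assuming the roberta-base tokenizer from above): the WordPiece-style cleanup only strips '##' continuation markers, which byte-level BPE tokens do not use, so RoBERTa predictions came out with Ġ markers embedded in the text.

bpe_tokens = ['ĠNew', 'ĠYork']

# Old path: WordPiece detokenization garbles BPE tokens.
old = " ".join(bpe_tokens).replace(" ##", "").replace("##", "")
print(old)                                      # 'ĠNew ĠYork'  -- garbled

# New path: let the tokenizer invert its own encoding.
new = tokenizer.convert_tokens_to_string(bpe_tokens)
print(" ".join(new.strip().split()))            # 'New York'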