Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fixing for roberta tokenizer decoding
commit 22e7c4edaf
parent ebb32261b1
@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions(examples, features, all_results, args.n_best_size,
                         args.max_answer_length, args.do_lower_case, output_prediction_file,
                         output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                        args.version_2_with_negative, args.null_score_diff_threshold)
+                        args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)
 
     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
                                                 max_query_length=args.max_query_length,
-                                                is_training=not evaluate)
+                                                is_training=not evaluate, add_prefix_space=True if args.model_type == 'roberta' else False)
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
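Context for the change above: RoBERTa's byte-level BPE (inherited from GPT-2) is whitespace-sensitive, so a word tokenized in isolation differs from the same word seen mid-sentence, where it carries a leading-space marker rendered as 'Ġ'. Passing add_prefix_space=True makes standalone strings tokenize as they would mid-sentence; note that True if args.model_type == 'roberta' else False could be written more simply as args.model_type == 'roberta'. A minimal sketch, assuming a transformers version whose RobertaTokenizer.tokenize forwards the add_prefix_space keyword:

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Byte-level BPE is whitespace-sensitive: without a leading space,
# the word maps to a different token than its mid-sentence form.
print(tokenizer.tokenize("world"))                         # e.g. ['world']
print(tokenizer.tokenize("world", add_prefix_space=True))  # e.g. ['Ġworld']

The remaining hunks leave run_squad.py and modify the SQuAD preprocessing utilities (utils_squad.py, judging by the utils_squad_evaluate import), threading the new flag through every tokenize call.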
@@ -25,6 +25,7 @@ import collections
 from io import open
 
 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_roberta import RobertaTokenizer
 
 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -192,7 +193,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True, add_prefix_space=False):
     """Loads a data file into a list of `InputBatch`s."""
 
     unique_id = 1000000000
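The new flag defaults to False, so existing BERT-style callers are unaffected. A condensed, hypothetical call site mirroring the load_and_cache_examples change above (the other keyword arguments are elided):

features = convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=args.max_seq_length,
    doc_stride=args.doc_stride,
    max_query_length=args.max_query_length,
    is_training=not evaluate,
    add_prefix_space=(args.model_type == 'roberta'))  # default False keeps BERT behavior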
@@ -205,8 +206,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
-
-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)
 
         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
+            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text)
+                example.orig_answer_text, add_prefix_space)
 
         # The -3 accounts for [CLS], [SEP] and [SEP]
         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 
 
 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+                         orig_answer_text, add_prefix_space):
     """Returns tokenized answer spans that better match the annotated answer."""
 
     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))
 
     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
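The flag has to reach _improve_answer_span as well because this helper compares the space-joined tokenization of the annotated answer against sub-spans of the document tokens, which for RoBERTa were produced with add_prefix_space=True. A sketch of the mismatch this avoids (the token strings are illustrative):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Document sub-tokens are built with add_prefix_space=True, so a
# mid-sentence answer word carries the 'Ġ' marker.
doc_sub_tokens = tokenizer.tokenize("world", add_prefix_space=True)  # e.g. ['Ġworld']

# Tokenized without the flag, the answer text can never equal the
# document span; with the flag, the exact-match scan succeeds again.
assert " ".join(tokenizer.tokenize("world")) != " ".join(doc_sub_tokens)
assert " ".join(tokenizer.tokenize("world", add_prefix_space=True)) == " ".join(doc_sub_tokens)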
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold):
+                      version_2_with_negative, null_score_diff_threshold, tokenizer, mode_type='bert'):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,15 +576,22 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             tok_text = " ".join(tok_tokens)
 
-            # De-tokenize WordPieces that have been split off.
-            tok_text = tok_text.replace(" ##", "")
-            tok_text = tok_text.replace("##", "")
+            if mode_type == 'roberta':
+                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+                tok_text = tok_text.replace("##", "")
+                tok_text = " ".join(tok_text.strip().split())
+                orig_text = " ".join(orig_tokens)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, None)
+            else:
+                tok_text = tok_text.replace(" ##", "")
+                tok_text = tok_text.replace("##", "")
 
-            # Clean whitespace
-            tok_text = tok_text.strip()
-            tok_text = " ".join(tok_text.split())
-            orig_text = " ".join(orig_tokens)
+                # Clean whitespace
+                tok_text = tok_text.strip()
+                tok_text = " ".join(tok_text.split())
+                orig_text = " ".join(orig_tokens)
 
-            final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
             if final_text in seen_predictions:
                 continue
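This final hunk is the heart of the fix: the old detokenization assumed WordPiece's '##' continuation markers, which byte-level BPE never emits, so RoBERTa predictions could come back with stray 'Ġ' markers and broken spacing. A sketch of the difference, assuming the standard transformers tokenizers:

from transformers import BertTokenizer, RobertaTokenizer

bert = BertTokenizer.from_pretrained('bert-base-uncased')
roberta = RobertaTokenizer.from_pretrained('roberta-base')

# WordPiece marks continuation pieces with '##', so stripping " ##"
# after a plain join reconstructs the surface text.
tokens = bert.tokenize("I have a new GPU!")      # e.g. ['i', 'have', 'a', 'new', 'gp', '##u', '!']
print(" ".join(tokens).replace(" ##", ""))       # 'i have a new gpu !'

# Byte-level BPE has no '##' convention; a plain join leaves 'Ġ'
# markers behind, so the tokenizer's own decoder is used instead.
tokens = roberta.tokenize("Hello world")         # e.g. ['Hello', 'Ġworld']
print(" ".join(tokens))                          # 'Hello Ġworld' (garbled)
print(roberta.convert_tokens_to_string(tokens))  # 'Hello world'

Two details worth noting: the parameter is spelled mode_type in the commit (presumably a typo for model_type), which still works because run_squad.py passes args.model_type positionally; and the extra None argument to get_final_text in the RoBERTa branch implies a signature change not shown in this excerpt.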