From c3ba6452377f085d0f59e15b97ac247bca24367e Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Fri, 22 Nov 2019 14:36:49 -0500
Subject: [PATCH] Works for XLNet

---
 examples/run_squad.py                 | 38 ++++--------
 transformers/data/processors/squad.py | 84 +++++++++++++--------------
 2 files changed, 50 insertions(+), 72 deletions(-)

diff --git a/examples/run_squad.py b/examples/run_squad.py
index d4219c3096c..634b566a463 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -16,6 +16,7 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 
 from __future__ import absolute_import, division, print_function
+from transformers.data.processors.squad import SquadV1Processor
 
 import argparse
 import logging
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
 
 from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
 
-from utils_squad import (read_squad_examples, convert_examples_to_features,
-                         RawResult, write_predictions,
+from utils_squad import (RawResult, write_predictions,
                          RawResultExtended, write_predictions_extended)
 
 # The follwing import is the official SQuAD evaluation script (2.0).
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
             results = evaluate_on_squad(evaluate_options)
     return results
 
-
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank not in [-1, 0] and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
@@ -308,9 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
         examples = read_squad_examples(input_file=input_file,
                                        is_training=not evaluate,
                                        version_2_with_negative=args.version_2_with_negative)
-
-        examples = examples[:10]
-        features = convert_examples_to_features(examples=examples,
+        keep_n_examples = 1000
+        processor = SquadV1Processor()
+        values = processor.get_dev_examples("examples/squad")
+        examples = values[:keep_n_examples]
+        features = squad_convert_examples_to_features(examples=examples,
                                                 tokenizer=tokenizer,
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
@@ -320,29 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
                                                 cls_token_at_end=True if args.model_type in ['xlnet'] else False,
                                                 sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-
-        exampless = sread_squad_examples(input_file=input_file,
-                                                is_training=not evaluate,
-                                                version_2_with_negative=args.version_2_with_negative)
-        exampless = exampless[:10]
-        features2 = squad_convert_examples_to_features(examples=exampless,
-                                                tokenizer=tokenizer,
-                                                max_seq_length=args.max_seq_length,
-                                                doc_stride=args.doc_stride,
-                                                max_query_length=args.max_query_length,
-                                                is_training=not evaluate,
-                                                cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
-                                                cls_token_at_end=True if args.model_type in ['xlnet'] else False,
-                                                sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-
-        print(features2)
-
-        for i in range(len(features)):
-            assert features[i] == features2[i]
-            print("Equal")
-        print("DONE")
+
+        import sys
+        sys.exit()
 
 
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py
index a0f2408a169..fb3d2ae4d42 100644
--- a/transformers/data/processors/squad.py
+++ b/transformers/data/processors/squad.py
@@ -83,6 +83,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                        sequence_a_is_doc=False):
     """Loads a data file into a list of `InputBatch`s."""
 
+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
+
     # Defining helper methods
     unique_id = 1000000000
 
@@ -136,24 +139,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
 
         encoded_dict = tokenizer.encode_plus(
-            truncated_query,
-            all_doc_tokens,
+            truncated_query if not sequence_a_is_doc else all_doc_tokens,
+            all_doc_tokens if not sequence_a_is_doc else truncated_query,
             max_length=max_seq_length,
             padding_strategy='right',
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
             return_overflowing_tokens=True,
-            truncation_strategy='only_second'
+            truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
         )
 
         ids = encoded_dict['input_ids']
-        print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
         non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids
         paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
         tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
 
         token_to_orig_map = {}
         for i in range(paragraph_len):
-            token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]
+            index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+            token_to_orig_map[index] = tok_to_orig_index[0 + i]
 
         encoded_dict["paragraph_len"] = paragraph_len
         encoded_dict["tokens"] = tokens
@@ -164,35 +167,40 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         encoded_dict["length"] = paragraph_len
 
         spans.append(encoded_dict)
-        print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
+        # print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
+
         while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
-
-            overflowing_tokens = encoded_dict['overflowing_tokens']
-
-            print("OVERFLOW", len(overflowing_tokens))
-
+            overflowing_tokens = encoded_dict["overflowing_tokens"]
             encoded_dict = tokenizer.encode_plus(
-                truncated_query,
-                overflowing_tokens,
+                truncated_query if not sequence_a_is_doc else overflowing_tokens,
+                overflowing_tokens if not sequence_a_is_doc else truncated_query,
                 max_length=max_seq_length,
                 return_overflowing_tokens=True,
                 padding_strategy='right',
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second'
+                truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
             )
 
-            ids = encoded_dict['input_ids']
-            print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+            # print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+
+            # print(encoded_dict["input_ids"].index(tokenizer.pad_token_id) if tokenizer.pad_token_id in encoded_dict["input_ids"] else None)
+            # print(len(spans) * doc_stride, len(all_doc_tokens))
+            # Length of the document without the query
             paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
 
-            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict['input_ids']:
+                non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            else:
+                non_padded_ids = encoded_dict['input_ids']
+
             tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
 
             token_to_orig_map = {}
             for i in range(paragraph_len):
-                token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]
+                index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+                token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
 
             encoded_dict["paragraph_len"] = paragraph_len
             encoded_dict["tokens"] = tokens
@@ -202,23 +210,14 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             encoded_dict["start"] = len(spans) * doc_stride
             encoded_dict["length"] = paragraph_len
 
-            # split_token_index = doc_span.start + i
-            # token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-
-            # is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-            #                                        split_token_index)
-            # token_is_max_context[len(tokens)] = is_max_context
-            # tokens.append(all_doc_tokens[split_token_index])
-
             spans.append(encoded_dict)
 
         for doc_span_index in range(len(spans)):
             for j in range(spans[doc_span_index]["paragraph_len"]):
                 is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                 spans[doc_span_index]["token_is_max_context"][index] = is_max_context
 
-        print("new span", len(spans))
         for span in spans:
             # Identify the position of the CLS token
             cls_index = span['input_ids'].index(tokenizer.cls_token_id)
@@ -227,17 +226,17 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Original TF implem also keep the classification token (set to 0) (not sure why...)
             p_mask = np.array(span['token_type_ids'])
 
-            # Convert all SEP indices to '0' before inversion
-            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0
+            p_mask = np.minimum(p_mask, 1)
 
-            # Limit positive values to one
-            p_mask = 1 - np.minimum(p_mask, 1)
+            if not sequence_a_is_doc:
+                # Limit positive values to one
+                p_mask = 1 - p_mask
+
+            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
 
             # Set the CLS index to '0'
             p_mask[cls_index] = 0
 
-            print("new features length", len(new_features))
-
             new_features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
@@ -287,19 +286,15 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         doc_spans = []
         start_offset = 0
         while start_offset < len(all_doc_tokens):
-            print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
             length = len(all_doc_tokens) - start_offset
             if length > max_tokens_for_doc:
                 length = max_tokens_for_doc
+            # print("Start offset is", start_offset, len(all_doc_tokens), "length is", length)
             doc_spans.append(_DocSpan(start=start_offset, length=length))
             if start_offset + length == len(all_doc_tokens):
-                print("Done with this doc span, breaking out.", start_offset, length)
                 break
-            print("CHOOSING OFFSET", length, doc_stride)
             start_offset += min(length, doc_stride)
-            print("OLD DOC CREATION END", start_offset)
 
-        print("old span", len(doc_spans))
         for (doc_span_index, doc_span) in enumerate(doc_spans):
             tokens = []
             token_to_orig_map = {}
@@ -382,7 +377,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 input_mask.append(0 if mask_padding_with_zero else 1)
                 segment_ids.append(pad_token_segment_id)
                 p_mask.append(1)
-            print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)
+
             assert len(input_ids) == max_seq_length
             assert len(input_mask) == max_seq_length
             assert len(segment_ids) == max_seq_length
@@ -440,7 +435,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # logger.info(
             #     "answer: %s" % (answer_text))
 
-            print("features length", len(features))
             features.append(
                 SquadFeatures(
                     unique_id=unique_id,
@@ -464,10 +458,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     assert len(features) == len(new_features)
 
     for i in range(len(features)):
-        print(i)
         feature, new_feature = features[i], new_features[i]
 
-        input_ids = feature.input_ids
+        input_ids = [f if f not in [3, 4, 5] else 0 for f in feature.input_ids]
         input_mask = feature.input_mask
         segment_ids = feature.segment_ids
         cls_index = feature.cls_index
@@ -478,7 +471,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         tokens = feature.tokens
         token_to_orig_map = feature.token_to_orig_map
 
-        new_input_ids = new_feature.input_ids
+        new_input_ids = [f if f not in [3, 4, 5] else 0 for f in new_feature.input_ids]
         new_input_mask = new_feature.attention_mask
         new_segment_ids = new_feature.token_type_ids
         new_cls_index = new_feature.cls_index
@@ -497,6 +490,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         assert example_index == new_example_index
         assert paragraph_len == new_paragraph_len
         assert token_is_max_context == new_token_is_max_context
+
+        tokens = [t if tokenizer.convert_tokens_to_ids(t) != tokenizer.unk_token_id else tokenizer.unk_token for t in tokens]
+
        assert tokens == new_tokens
         assert token_to_orig_map == new_token_to_orig_map
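
Note on usage: the new conversion path exercised by this patch can be driven on its own roughly as follows. This is a minimal sketch assembled only from the calls visible in the diff above; the checkpoint name, the data directory, and the 384/128/64 hyper-parameters are illustrative placeholders rather than values taken from the commit.

    from transformers import XLNetTokenizer, squad_convert_examples_to_features
    from transformers.data.processors.squad import SquadV1Processor

    # Checkpoint is a placeholder; any XLNet checkpoint exercises the same code path.
    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

    # Load SQuAD v1.1 dev examples through the processor API used in run_squad.py above.
    processor = SquadV1Processor()
    examples = processor.get_dev_examples("examples/squad")  # directory containing the dev JSON file

    # Mirror the XLNet-specific arguments passed in load_and_cache_examples().
    features = squad_convert_examples_to_features(examples=examples[:1000],
                                                  tokenizer=tokenizer,
                                                  max_seq_length=384,
                                                  doc_stride=128,
                                                  max_query_length=64,
                                                  is_training=False,
                                                  cls_token_segment_id=2,   # XLNet: CLS token gets segment id 2
                                                  pad_token_segment_id=3,   # XLNet: padding gets segment id 3
                                                  cls_token_at_end=True,    # XLNet: CLS is appended at the end
                                                  sequence_a_is_doc=True)   # XLNet: the document is sequence A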