Works for XLNet

Lysandre 2019-11-22 14:36:49 -05:00 committed by LysandreJik
parent a5a8a6175f
commit c3ba645237
2 changed files with 50 additions and 72 deletions

View File

@@ -16,6 +16,7 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor
import argparse
import logging
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples
from utils_squad import (read_squad_examples, convert_examples_to_features,
RawResult, write_predictions,
from utils_squad import (RawResult, write_predictions,
RawResultExtended, write_predictions_extended)
# The following import is the official SQuAD evaluation script (2.0).
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
results = evaluate_on_squad(evaluate_options)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
@@ -308,9 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
examples = read_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
examples = examples[:10]
features = convert_examples_to_features(examples=examples,
keep_n_examples = 1000
processor = SquadV1Processor()
values = processor.get_dev_examples("examples/squad")
examples = values[:keep_n_examples]
features = squad_convert_examples_to_features(examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
@@ -320,29 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
exampless = sread_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
exampless = exampless[:10]
features2 = squad_convert_examples_to_features(examples=exampless,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
print(features2)
for i in range(len(features)):
assert features[i] == features2[i]
print("Equal")
print("DONE")
import sys
sys.exit()
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)

View File

@@ -83,6 +83,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
sequence_a_is_doc=False):
"""Loads a data file into a list of `InputBatch`s."""
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
# Defining helper methods
unique_id = 1000000000
@@ -136,24 +139,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
encoded_dict = tokenizer.encode_plus(
truncated_query,
all_doc_tokens,
truncated_query if not sequence_a_is_doc else all_doc_tokens,
all_doc_tokens if not sequence_a_is_doc else truncated_query,
max_length=max_seq_length,
padding_strategy='right',
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
return_overflowing_tokens=True,
truncation_strategy='only_second'
truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
)
ids = encoded_dict['input_ids']
print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
token_to_orig_map = {}
for i in range(paragraph_len):
token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]
index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
token_to_orig_map[index] = tok_to_orig_index[0 + i]
encoded_dict["paragraph_len"] = paragraph_len
encoded_dict["tokens"] = tokens
@@ -164,35 +167,40 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
encoded_dict["length"] = paragraph_len
spans.append(encoded_dict)
print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
# print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
overflowing_tokens = encoded_dict['overflowing_tokens']
print("OVERFLOW", len(overflowing_tokens))
overflowing_tokens = encoded_dict["overflowing_tokens"]
encoded_dict = tokenizer.encode_plus(
truncated_query,
overflowing_tokens,
truncated_query if not sequence_a_is_doc else overflowing_tokens,
overflowing_tokens if not sequence_a_is_doc else truncated_query,
max_length=max_seq_length,
return_overflowing_tokens=True,
padding_strategy='right',
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy='only_second'
truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
)
ids = encoded_dict['input_ids']
print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
# print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
# print(encoded_dict["input_ids"].index(tokenizer.pad_token_id) if tokenizer.pad_token_id in encoded_dict["input_ids"] else None)
# print(len(spans) * doc_stride, len(all_doc_tokens))
# Length of the document without the query
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
if tokenizer.pad_token_id in encoded_dict['input_ids']:
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
else:
non_padded_ids = encoded_dict['input_ids']
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
token_to_orig_map = {}
for i in range(paragraph_len):
token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]
index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
encoded_dict["paragraph_len"] = paragraph_len
encoded_dict["tokens"] = tokens
@@ -202,23 +210,14 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
encoded_dict["start"] = len(spans) * doc_stride
encoded_dict["length"] = paragraph_len
# split_token_index = doc_span.start + i
# token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
# is_max_context = _check_is_max_context(doc_spans, doc_span_index,
# split_token_index)
# token_is_max_context[len(tokens)] = is_max_context
# tokens.append(all_doc_tokens[split_token_index])
spans.append(encoded_dict)
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
print("new span", len(spans))
for span in spans:
# Identify the position of the CLS token
cls_index = span['input_ids'].index(tokenizer.cls_token_id)
@@ -227,17 +226,17 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
# Original TF implem also keeps the classification token (set to 0) (not sure why...)
p_mask = np.array(span['token_type_ids'])
# Convert all SEP indices to '0' before inversion
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0
p_mask = np.minimum(p_mask, 1)
# Limit positive values to one
p_mask = 1 - np.minimum(p_mask, 1)
if not sequence_a_is_doc:
# Limit positive values to one
p_mask = 1 - p_mask
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
# Set the CLS index to '0'
p_mask[cls_index] = 0
print("new features length", len(new_features))
new_features.append(NewSquadFeatures(
span['input_ids'],
span['attention_mask'],
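The p_mask logic earlier in this hunk decides which positions may be predicted as answer tokens (1 = cannot be part of the answer, 0 = can). A commented restatement of the new branch, reusing the names from the surrounding loop:

# token_type_ids: segment 0 / segment 1 (plus the CLS segment id for XLNet), capped at 1
p_mask = np.minimum(np.array(span['token_type_ids']), 1)
if not sequence_a_is_doc:
    # BERT-style: the document is segment 1, so invert to flag the query (segment 0) with 1
    p_mask = 1 - p_mask
# XLNet-style inputs need no inversion: the document is already segment 0
p_mask[np.where(np.array(span['input_ids']) == tokenizer.sep_token_id)[0]] = 1   # SEP tokens can never be answers
p_mask[cls_index] = 0   # keep CLS available for the unanswerable prediction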
@@ -287,19 +286,15 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
# print("Start offset is", start_offset, len(all_doc_tokens), "length is", length)
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
print("Done with this doc span, breaking out.", start_offset, length)
break
print("CHOOSING OFFSET", length, doc_stride)
start_offset += min(length, doc_stride)
print("OLD DOC CREATION END", start_offset)
print("old span", len(doc_spans))
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_to_orig_map = {}
@@ -382,7 +377,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
input_mask.append(0 if mask_padding_with_zero else 1)
segment_ids.append(pad_token_segment_id)
p_mask.append(1)
print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
@@ -440,7 +435,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
# logger.info(
# "answer: %s" % (answer_text))
print("features length", len(features))
features.append(
SquadFeatures(
unique_id=unique_id,
@@ -464,10 +458,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
assert len(features) == len(new_features)
for i in range(len(features)):
print(i)
feature, new_feature = features[i], new_features[i]
input_ids = feature.input_ids
input_ids = [f if f not in [3,4,5] else 0 for f in feature.input_ids ]
input_mask = feature.input_mask
segment_ids = feature.segment_ids
cls_index = feature.cls_index
@@ -478,7 +471,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
tokens = feature.tokens
token_to_orig_map = feature.token_to_orig_map
new_input_ids = new_feature.input_ids
new_input_ids = [f if f not in [3,4,5] else 0 for f in new_feature.input_ids]
new_input_mask = new_feature.attention_mask
new_segment_ids = new_feature.token_type_ids
new_cls_index = new_feature.cls_index
@@ -497,6 +490,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
assert example_index == new_example_index
assert paragraph_len == new_paragraph_len
assert token_is_max_context == new_token_is_max_context
tokens = [t if tokenizer.convert_tokens_to_ids(t) != tokenizer.unk_token_id else tokenizer.unk_token for t in tokens]
assert tokens == new_tokens
assert token_to_orig_map == new_token_to_orig_map
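The final hunks relax the old-vs-new feature comparison so that differences caused only by special tokens do not trip the asserts: input ids 3, 4 and 5 (apparently special-token ids in this setup) are zeroed on both sides, and tokens whose ids round-trip to the unknown id are collapsed onto the unk token before comparing. A sketch of that normalisation, with hypothetical old_feature/new_feature/tokenizer stand-ins for features[i], new_features[i] and the tokenizer in use:

SPECIAL_IDS = {3, 4, 5}   # assumption: the ids the diff masks out before comparing

def normalise_ids(input_ids):
    # zero special-token ids so the two feature sets only differ where it matters
    return [0 if token_id in SPECIAL_IDS else token_id for token_id in input_ids]

def normalise_tokens(tokens, tokenizer):
    # collapse pieces the tokenizer cannot round-trip onto the unk token
    return [tokenizer.unk_token if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id else t
            for t in tokens]

# old_feature / new_feature / tokenizer are hypothetical stand-ins
assert normalise_ids(old_feature.input_ids) == normalise_ids(new_feature.input_ids)
assert normalise_tokens(old_feature.tokens, tokenizer) == new_feature.tokens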