mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

Works for XLNet

This commit is contained in:
parent a5a8a6175f
commit c3ba645237
@@ -16,6 +16,7 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""

 from __future__ import absolute_import, division, print_function
+from transformers.data.processors.squad import SquadV1Processor

 import argparse
 import logging
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,

 from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples

-from utils_squad import (read_squad_examples, convert_examples_to_features,
-                         RawResult, write_predictions,
+from utils_squad import (RawResult, write_predictions,
                          RawResultExtended, write_predictions_extended)

 # The following import is the official SQuAD evaluation script (2.0).
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
     results = evaluate_on_squad(evaluate_options)
     return results

-
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank not in [-1, 0] and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
@@ -308,9 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         examples = read_squad_examples(input_file=input_file,
                                        is_training=not evaluate,
                                        version_2_with_negative=args.version_2_with_negative)
-
-        examples = examples[:10]
-        features = convert_examples_to_features(examples=examples,
+        keep_n_examples = 1000
+        processor = SquadV1Processor()
+        values = processor.get_dev_examples("examples/squad")
+        examples = values[:keep_n_examples]
+        features = squad_convert_examples_to_features(examples=examples,
                                                 tokenizer=tokenizer,
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
@@ -320,29 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
                                                 pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
                                                 cls_token_at_end=True if args.model_type in ['xlnet'] else False,
                                                 sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-
-        exampless = sread_squad_examples(input_file=input_file,
-                                         is_training=not evaluate,
-                                         version_2_with_negative=args.version_2_with_negative)
-        exampless = exampless[:10]
-        features2 = squad_convert_examples_to_features(examples=exampless,
-                                                       tokenizer=tokenizer,
-                                                       max_seq_length=args.max_seq_length,
-                                                       doc_stride=args.doc_stride,
-                                                       max_query_length=args.max_query_length,
-                                                       is_training=not evaluate,
-                                                       cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                       pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
-                                                       cls_token_at_end=True if args.model_type in ['xlnet'] else False,
-                                                       sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-
-        print(features2)
-
-        for i in range(len(features)):
-            assert features[i] == features2[i]
-            print("Equal")
-
-        print("DONE")
-
-        import sys
-        sys.exit()

     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
@@ -83,6 +83,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                        sequence_a_is_doc=False):
     """Loads a data file into a list of `InputBatch`s."""

+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
+
     # Defining helper methods
     unique_id = 1000000000

@@ -136,24 +139,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

         encoded_dict = tokenizer.encode_plus(
-            truncated_query,
-            all_doc_tokens,
+            truncated_query if not sequence_a_is_doc else all_doc_tokens,
+            all_doc_tokens if not sequence_a_is_doc else truncated_query,
             max_length=max_seq_length,
             padding_strategy='right',
             stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
             return_overflowing_tokens=True,
-            truncation_strategy='only_second'
+            truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
         )

         ids = encoded_dict['input_ids']
-        print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
         non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids
         paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
         tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

         token_to_orig_map = {}
         for i in range(paragraph_len):
-            token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]
+            index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+            token_to_orig_map[index] = tok_to_orig_index[0 + i]

         encoded_dict["paragraph_len"] = paragraph_len
         encoded_dict["tokens"] = tokens
@@ -164,35 +167,40 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         encoded_dict["length"] = paragraph_len

         spans.append(encoded_dict)
-        print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
+        # print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)

         while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
-
-            overflowing_tokens = encoded_dict['overflowing_tokens']
-
-            print("OVERFLOW", len(overflowing_tokens))
-
+            overflowing_tokens = encoded_dict["overflowing_tokens"]
             encoded_dict = tokenizer.encode_plus(
-                truncated_query,
-                overflowing_tokens,
+                truncated_query if not sequence_a_is_doc else overflowing_tokens,
+                overflowing_tokens if not sequence_a_is_doc else truncated_query,
                 max_length=max_seq_length,
                 return_overflowing_tokens=True,
                 padding_strategy='right',
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second'
+                truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
             )

             ids = encoded_dict['input_ids']
-            print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+            # print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+            # print(encoded_dict["input_ids"].index(tokenizer.pad_token_id) if tokenizer.pad_token_id in encoded_dict["input_ids"] else None)
+            # print(len(spans) * doc_stride, len(all_doc_tokens))

             # Length of the document without the query
             paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)

-            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict['input_ids']:
+                non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            else:
+                non_padded_ids = encoded_dict['input_ids']

             tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

             token_to_orig_map = {}
             for i in range(paragraph_len):
-                token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]
+                index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+                token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

             encoded_dict["paragraph_len"] = paragraph_len
             encoded_dict["tokens"] = tokens
@@ -202,23 +210,14 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             encoded_dict["start"] = len(spans) * doc_stride
             encoded_dict["length"] = paragraph_len

-            # split_token_index = doc_span.start + i
-            # token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-
-            # is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-            #                                        split_token_index)
-            # token_is_max_context[len(tokens)] = is_max_context
-            # tokens.append(all_doc_tokens[split_token_index])
-
             spans.append(encoded_dict)

         for doc_span_index in range(len(spans)):
             for j in range(spans[doc_span_index]["paragraph_len"]):
                 is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                 spans[doc_span_index]["token_is_max_context"][index] = is_max_context

-        print("new span", len(spans))
         for span in spans:
             # Identify the position of the CLS token
             cls_index = span['input_ids'].index(tokenizer.cls_token_id)
@@ -227,17 +226,17 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Original TF implem also keeps the classification token (set to 0) (not sure why...)
             p_mask = np.array(span['token_type_ids'])

-            # Convert all SEP indices to '0' before inversion
-            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0
+            p_mask = np.minimum(p_mask, 1)

-            # Limit positive values to one
-            p_mask = 1 - np.minimum(p_mask, 1)
+            if not sequence_a_is_doc:
+                # Limit positive values to one
+                p_mask = 1 - p_mask
+
+            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1

             # Set the CLS index to '0'
             p_mask[cls_index] = 0

-            print("new features length", len(new_features))
-
             new_features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
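Note: p_mask marks positions that cannot contain the answer with 1 (query tokens, SEP, padding) and answer-eligible document tokens with 0, with CLS forced back to 0 so the "no answer" prediction stays possible. A toy sketch of the new construction for the BERT-style case (the ids are made up, not a real vocabulary):

    import numpy as np

    sep_token_id, cls_token_id = 102, 101
    input_ids      = [101, 7, 8, 102, 20, 21, 22, 102]  # [CLS] query [SEP] doc [SEP]
    token_type_ids = [0,   0, 0, 0,   1,  1,  1,  1]

    p_mask = np.minimum(np.array(token_type_ids), 1)
    sequence_a_is_doc = False            # BERT-style: the document is segment 1
    if not sequence_a_is_doc:
        p_mask = 1 - p_mask              # make segment-1 (document) tokens eligible
    p_mask[np.array(input_ids) == sep_token_id] = 1  # SEP can never be the answer
    p_mask[input_ids.index(cls_token_id)] = 0        # keep CLS for "no answer"
    print(p_mask)  # -> [0 1 1 1 0 0 0 1]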
@@ -287,19 +286,15 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         doc_spans = []
         start_offset = 0
         while start_offset < len(all_doc_tokens):
-            print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
             length = len(all_doc_tokens) - start_offset
             if length > max_tokens_for_doc:
                 length = max_tokens_for_doc
+            # print("Start offset is", start_offset, len(all_doc_tokens), "length is", length)
             doc_spans.append(_DocSpan(start=start_offset, length=length))
             if start_offset + length == len(all_doc_tokens):
-                print("Done with this doc span, breaking out.", start_offset, length)
                 break
-            print("CHOOSING OFFSET", length, doc_stride)
             start_offset += min(length, doc_stride)
-        print("OLD DOC CREATION END", start_offset)

-        print("old span", len(doc_spans))
         for (doc_span_index, doc_span) in enumerate(doc_spans):
             tokens = []
             token_to_orig_map = {}
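Note: stripped of its debug prints, the legacy span creation above is a plain sliding window: each span covers up to max_tokens_for_doc tokens and the window advances by doc_stride. A standalone sketch of the same loop:

    import collections

    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

    def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
        doc_spans, start_offset = [], 0
        while start_offset < num_doc_tokens:
            length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == num_doc_tokens:
                break
            start_offset += min(length, doc_stride)
        return doc_spans

    # e.g. 10 tokens, window of 4, stride 2 -> spans starting at 0, 2, 4, 6
    print(make_doc_spans(10, 4, 2))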
@@ -382,7 +377,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 input_mask.append(0 if mask_padding_with_zero else 1)
                 segment_ids.append(pad_token_segment_id)
                 p_mask.append(1)
-            print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)

             assert len(input_ids) == max_seq_length
             assert len(input_mask) == max_seq_length
             assert len(segment_ids) == max_seq_length
@@ -440,7 +435,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # logger.info(
             #     "answer: %s" % (answer_text))

-            print("features length", len(features))
             features.append(
                 SquadFeatures(
                     unique_id=unique_id,
@@ -464,10 +458,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

     assert len(features) == len(new_features)
     for i in range(len(features)):
-        print(i)
         feature, new_feature = features[i], new_features[i]

-        input_ids = feature.input_ids
+        input_ids = [f if f not in [3, 4, 5] else 0 for f in feature.input_ids]
         input_mask = feature.input_mask
         segment_ids = feature.segment_ids
         cls_index = feature.cls_index
@@ -478,7 +471,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         tokens = feature.tokens
         token_to_orig_map = feature.token_to_orig_map

-        new_input_ids = new_feature.input_ids
+        new_input_ids = [f if f not in [3, 4, 5] else 0 for f in new_feature.input_ids]
         new_input_mask = new_feature.attention_mask
         new_segment_ids = new_feature.token_type_ids
         new_cls_index = new_feature.cls_index
@@ -497,6 +490,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         assert example_index == new_example_index
         assert paragraph_len == new_paragraph_len
         assert token_is_max_context == new_token_is_max_context
+
+        tokens = [t if tokenizer.convert_tokens_to_ids(t) != tokenizer.unk_token_id else tokenizer.unk_token for t in tokens]
+
         assert tokens == new_tokens
         assert token_to_orig_map == new_token_to_orig_map
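Note: the remapping in the comparison above exists because the two pipelines may place special tokens differently; ids 3, 4 and 5 are XLNet's <cls>, <sep> and <pad> ids, collapsed to 0 on both sides before comparing, and token strings that round-trip to the unknown id are canonicalized to unk_token. A sketch of the two normalizations (the helper names are illustrative, not from the diff):

    SPECIAL_IDS = {3, 4, 5}  # <cls>, <sep>, <pad> in XLNet's vocabulary

    def normalize_ids(ids):
        # Collapse special-token ids so the check compares only real content ids.
        return [0 if i in SPECIAL_IDS else i for i in ids]

    def normalize_tokens(tokens, tokenizer):
        # Canonicalize anything that maps to the unknown id before comparing strings.
        return [tokenizer.unk_token
                if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
                else t
                for t in tokens]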