mirror of https://github.com/huggingface/transformers.git
convert_examples_to_features code and small improvements.
This commit is contained in:
parent 83fdbd6043
commit f2b873e995
@@ -16,6 +16,15 @@
 import pandas as pd
+import logging
+
+from pytorch_pretrained_bert.tokenization import BertTokenizer
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 
 class SwagExample(object):
     """A single training/test example for the SWAG dataset."""
@@ -31,10 +40,12 @@ class SwagExample(object):
         self.swag_id = swag_id
         self.context_sentence = context_sentence
         self.start_ending = start_ending
-        self.ending_0 = ending_0
-        self.ending_1 = ending_1
-        self.ending_2 = ending_2
-        self.ending_3 = ending_3
+        self.endings = [
+            ending_0,
+            ending_1,
+            ending_2,
+            ending_3,
+        ]
         self.label = label
 
     def __str__(self):
@@ -42,19 +53,37 @@ class SwagExample(object):
     def __repr__(self):
         l = [
-            f'swag_id: {self.swag_id}',
-            f'context_sentence: {self.context_sentence}',
-            f'start_ending: {self.start_ending}',
-            f'ending_0: {self.ending_0}',
-            f'ending_1: {self.ending_1}',
-            f'ending_2: {self.ending_2}',
-            f'ending_3: {self.ending_3}',
+            f"swag_id: {self.swag_id}",
+            f"context_sentence: {self.context_sentence}",
+            f"start_ending: {self.start_ending}",
+            f"ending_0: {self.endings[0]}",
+            f"ending_1: {self.endings[1]}",
+            f"ending_2: {self.endings[2]}",
+            f"ending_3: {self.endings[3]}",
         ]
 
         if self.label is not None:
-            l.append(f'label: {self.label}')
+            l.append(f"label: {self.label}")
 
+        return ", ".join(l)
+
+
+class InputFeatures(object):
+    def __init__(self,
+                 unique_id,
+                 example_id,
+                 input_ids,
+                 input_mask,
+                 segment_ids,
+                 label_id
+    ):
+        self.unique_id = unique_id
+        self.example_id = example_id
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.label_id = label_id
-        return ', '.join(l)
 
 
 def read_swag_examples(input_file, is_training):
     input_df = pd.read_csv(input_file)
@@ -67,7 +96,9 @@ def read_swag_examples(input_file, is_training):
         SwagExample(
             swag_id = row['fold-ind'],
             context_sentence = row['sent1'],
-            start_ending = row['sent2'],
+            start_ending = row['sent2'], # in the swag dataset, the
+                                         # common beginning of each
+                                         # choice is stored in "sent2".
             ending_0 = row['ending0'],
             ending_1 = row['ending1'],
             ending_2 = row['ending2'],
@@ -79,9 +110,100 @@ def read_swag_examples(input_file, is_training):
     return examples
 
 
+def convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                 is_training):
+    """Loads a data file into a list of `InputBatch`s."""
+
+    # Swag is a multiple choice task. To perform this task using Bert,
+    # we will use the formatting proposed in "Improving Language
+    # Understanding by Generative Pre-Training" and suggested by
+    # @jacobdevlin-google in this issue
+    # https://github.com/google-research/bert/issues/38.
+    #
+    # Each choice will correspond to a sample on which we run the
+    # inference. For a given Swag example, we will create the 4
+    # following inputs:
+    # - [CLS] context [SEP] choice_1 [SEP]
+    # - [CLS] context [SEP] choice_2 [SEP]
+    # - [CLS] context [SEP] choice_3 [SEP]
+    # - [CLS] context [SEP] choice_4 [SEP]
+    # The model will output a single value for each input. To get the
+    # final decision of the model, we will run a softmax over these 4
+    # outputs.
+    features = []
+    for example_index, example in enumerate(examples):
+        context_tokens = tokenizer.tokenize(example.context_sentence)
+        start_ending_tokens = tokenizer.tokenize(example.start_ending)
+
+        choices_features = []
+        for ending_index, ending in enumerate(example.endings):
+            # We create a copy of the context tokens in order to be
+            # able to shrink it according to ending_tokens
+            context_tokens_choice = context_tokens[:]
+            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
+            # Modifies `context_tokens_choice` and `ending_tokens` in
+            # place so that the total length is less than the
+            # specified length. Account for [CLS], [SEP], [SEP] with
+            # "- 3"
+            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
+
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
|
||||
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
padding = [0] * (max_seq_length - len(input_ids))
|
||||
input_ids += padding
|
||||
input_mask += padding
|
||||
segment_ids += padding
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
choices_features.append((tokens, input_ids, input_mask, segment_ids))
|
||||
|
||||
label = example.label
|
||||
if example_index < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info(f"swag_id: {example.swag_id}")
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info(f"choice: {choice_idx}")
|
||||
logger.info(f"tokens: {' '.join(tokens)}")
|
||||
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
|
||||
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
|
||||
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
|
||||
if is_training:
|
||||
logger.info(f"label: {label}")
|
||||
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
examples = read_swag_examples('data/train.csv', True)
|
||||
is_training = True
|
||||
max_seq_length = 80
|
||||
examples = read_swag_examples('data/train.csv', is_training)
|
||||
print(len(examples))
|
||||
for example in examples[:5]:
|
||||
print('###########################')
|
||||
print("###########################")
|
||||
print(example)
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
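
For illustration only, a minimal sketch of the decision step the comments above describe: once the model has produced one score per [CLS] context [SEP] choice [SEP] input, a softmax over the four scores picks the ending. The `choice_logits` values below are hypothetical placeholders for those per-choice scores, not output of this code.

import torch

# Hypothetical per-choice scores for one SWAG example, as if produced by
# running a model on each of its four [CLS] context [SEP] choice [SEP] inputs.
choice_logits = torch.tensor([1.2, -0.3, 0.8, 2.1])

# Softmax over the 4 outputs gives a probability per ending; argmax is the prediction.
choice_probs = torch.softmax(choice_logits, dim=-1)
predicted_ending = int(torch.argmax(choice_probs))
print(predicted_ending)  # -> 3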