# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

import logging

import pandas as pd

from pytorch_pretrained_bert.tokenization import BertTokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


class SwagExample(object):
    """A single training/test example for the SWAG dataset."""

    def __init__(self,
                 swag_id,
                 context_sentence,
                 start_ending,
                 ending_0,
                 ending_1,
                 ending_2,
                 ending_3,
                 label=None):
        self.swag_id = swag_id
        self.context_sentence = context_sentence
        self.start_ending = start_ending
        self.endings = [
            ending_0,
            ending_1,
            ending_2,
            ending_3,
        ]
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        lines = [
            f"swag_id: {self.swag_id}",
            f"context_sentence: {self.context_sentence}",
            f"start_ending: {self.start_ending}",
            f"ending_0: {self.endings[0]}",
            f"ending_1: {self.endings[1]}",
            f"ending_2: {self.endings[2]}",
            f"ending_3: {self.endings[3]}",
        ]

        if self.label is not None:
            lines.append(f"label: {self.label}")

        return ", ".join(lines)


class InputFeatures(object):
    """Features for one example: a (tokens, input_ids, input_mask, segment_ids) tuple per choice."""

    def __init__(self,
                 example_id,
                 choices_features,
                 label):
        self.example_id = example_id
        self.choices_features = choices_features
        self.label = label


def read_swag_examples(input_file, is_training):
    """Reads a SWAG csv file into a list of `SwagExample`s."""
    input_df = pd.read_csv(input_file)

    if is_training and 'label' not in input_df.columns:
        raise ValueError(
            "For training, the input file must contain a label column.")

    examples = [
        SwagExample(
            swag_id=row['fold-ind'],
            context_sentence=row['sent1'],
            # In the SWAG dataset, the common beginning of each choice
            # is stored in "sent2".
            start_ending=row['sent2'],
            ending_0=row['ending0'],
            ending_1=row['ending1'],
            ending_2=row['ending2'],
            ending_3=row['ending3'],
            label=row['label'] if is_training else None
        ) for _, row in input_df.iterrows()
    ]

    return examples


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 is_training):
    """Loads a data file into a list of `InputFeatures`."""

    # SWAG is a multiple choice task. To perform this task using BERT,
    # we will use the formatting proposed in "Improving Language
    # Understanding by Generative Pre-Training" and suggested by
    # @jacobdevlin-google in this issue
    # https://github.com/google-research/bert/issues/38.
    #
    # Each choice will correspond to a sample on which we run the
    # inference. For a given SWAG example, we will create the 4
    # following inputs:
    # - [CLS] context [SEP] choice_1 [SEP]
    # - [CLS] context [SEP] choice_2 [SEP]
    # - [CLS] context [SEP] choice_3 [SEP]
    # - [CLS] context [SEP] choice_4 [SEP]
    # The model will output a single value for each input. To get the
    # final decision of the model, we will run a softmax over these 4
    # outputs.
    features = []
    for example_index, example in enumerate(examples):
        context_tokens = tokenizer.tokenize(example.context_sentence)
        start_ending_tokens = tokenizer.tokenize(example.start_ending)

        choices_features = []
        for ending_index, ending in enumerate(example.endings):
            # We create a copy of the context tokens in order to be
            # able to shrink it according to ending_tokens
            context_tokens_choice = context_tokens[:]
            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
            # Modifies `context_tokens_choice` and `ending_tokens` in
            # place so that the total length is less than the
            # specified length. Account for [CLS], [SEP], [SEP] with "- 3".
            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)

            tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
            segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            choices_features.append((tokens, input_ids, input_mask, segment_ids))

        label = example.label
        if example_index < 5:
            logger.info("*** Example ***")
            logger.info(f"swag_id: {example.swag_id}")
            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
                logger.info(f"choice: {choice_idx}")
                logger.info(f"tokens: {' '.join(tokens)}")
                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
            if is_training:
                logger.info(f"label: {label}")

        features.append(
            InputFeatures(
                example_id=example.swag_id,
                choices_features=choices_features,
                label=label
            )
        )

    return features


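# Illustrative sketch, not part of the original script: one possible way to
# unpack the nested `choices_features` tuples produced above into PyTorch
# tensors of shape (num_examples, num_choices, max_seq_length), e.g. for a
# DataLoader. The helper name is ours; it assumes torch is importable, which
# pytorch_pretrained_bert already requires.
def features_to_tensors(features):
    """Stacks per-choice features into (num_examples, num_choices, max_seq_length) tensors."""
    import torch  # local import so the sketch stays self-contained

    all_input_ids = torch.tensor(
        [[input_ids for _, input_ids, _, _ in f.choices_features] for f in features],
        dtype=torch.long)
    all_input_mask = torch.tensor(
        [[input_mask for _, _, input_mask, _ in f.choices_features] for f in features],
        dtype=torch.long)
    all_segment_ids = torch.tensor(
        [[segment_ids for _, _, _, segment_ids in f.choices_features] for f in features],
        dtype=torch.long)
    # Assumes labels are present, i.e. the features were built with is_training=True.
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    return all_input_ids, all_input_mask, all_segment_ids, all_labels

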
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


if __name__ == "__main__":
|
|
is_training = True
|
|
max_seq_length = 80
|
|
examples = read_swag_examples('data/train.csv', is_training)
|
|
print(len(examples))
|
|
for example in examples[:5]:
|
|
print("###########################")
|
|
print(example)
|
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
|
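
    # Illustrative only, not in the original script: tensorize the features with
    # the sketch helper defined above and print the resulting shapes as a sanity check.
    input_ids, input_mask, segment_ids, labels = features_to_tensors(features)
    print(input_ids.shape, input_mask.shape, segment_ids.shape, labels.shape)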