Removing the dependency on pandas and using the csv module to load data.

This commit is contained in:
Grégory Châtel 2018-12-10 17:45:23 +01:00
parent 0876b77f7f
commit df34f22854

View File

@ -14,13 +14,12 @@
# limitations under the License.
"""BERT finetuning runner."""
import pandas as pd
import logging
import os
import argparse
import random
from tqdm import tqdm, trange
import csv
import numpy as np
import torch
@ -100,25 +99,28 @@ class InputFeatures(object):
def read_swag_examples(input_file, is_training):
    """Load SWAG multiple-choice examples from a CSV file.

    The first row of the file must be a header. Columns are located by
    header name ("fold-ind", "sent1", "sent2", "ending0" ... "ending3" and,
    for training, "label") instead of by fixed position, so extra or
    reordered columns do not silently corrupt the examples.

    Args:
        input_file: path of the CSV file to read.
        is_training: when true, a "label" column is required and each label
            is converted to ``int``; otherwise every example gets
            ``label=None``.

    Returns:
        A list of ``SwagExample`` objects, one per non-empty data row, in
        file order.

    Raises:
        ValueError: if the file has no header row, if a required column is
            missing, or if ``is_training`` is true and there is no "label"
            column (or a label cannot be parsed as an integer).
    """
    # newline='' lets the csv module do its own line-ending handling, which
    # is required for quoted fields that contain embedded newlines.
    with open(input_file, 'r', encoding='utf-8', newline='') as f:
        rows = list(csv.reader(f))

    if not rows:
        raise ValueError(
            "The input file is empty: {}".format(input_file)
        )

    # Map each column name of the header row to its position.
    column = {name: index for index, name in enumerate(rows[0])}

    if is_training and 'label' not in column:
        raise ValueError(
            "For training, the input file must contain a label column."
        )

    required = (
        'fold-ind', 'sent1', 'sent2',
        'ending0', 'ending1', 'ending2', 'ending3',
    )
    missing = [name for name in required if name not in column]
    if missing:
        raise ValueError(
            "The input file is missing required column(s): {}".format(
                ", ".join(missing)
            )
        )

    examples = [
        SwagExample(
            swag_id=row[column['fold-ind']],
            context_sentence=row[column['sent1']],
            # In the SWAG dataset, the common beginning of each choice is
            # stored in "sent2".
            start_ending=row[column['sent2']],
            ending_0=row[column['ending0']],
            ending_1=row[column['ending1']],
            ending_2=row[column['ending2']],
            ending_3=row[column['ending3']],
            label=int(row[column['label']]) if is_training else None,
        )
        for row in rows[1:]  # skip the header row
        if row  # csv.reader yields [] for blank lines; ignore them
    ]

    return examples