mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-22 14:00:33 +06:00
Added script for training a discriminator for pplm to use
This commit is contained in:
parent
34a83faabe
commit
0b51fba20b
@ -34,6 +34,7 @@ import torch.nn.functional as F
|
|||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
from examples.run_pplm_discrim_train import ClassificationHead
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2Tokenizer
|
||||||
from transformers.file_utils import cached_path
|
from transformers.file_utils import cached_path
|
||||||
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
||||||
@ -108,24 +109,6 @@ def top_k_filter(logits, k, probs=False):
|
|||||||
logits)
|
logits)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationHead(torch.nn.Module):
|
|
||||||
""" Classification Head for the transformer """
|
|
||||||
|
|
||||||
def __init__(self, class_size=5, embed_size=2048):
|
|
||||||
super(ClassificationHead, self).__init__()
|
|
||||||
self.class_size = class_size
|
|
||||||
self.embed_size = embed_size
|
|
||||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
|
||||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
|
||||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
|
||||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
|
||||||
# hidden_state = self.mlp2(hidden_state)
|
|
||||||
logits = self.mlp(hidden_state)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def perturb_past(past, model, prev, args, classifier, good_index=None,
|
def perturb_past(past, model, prev, args, classifier, good_index=None,
|
||||||
stepsize=0.01, vocab_size=50257,
|
stepsize=0.01, vocab_size=50257,
|
||||||
original_probs=None, accumulated_hidden=None, true_past=None,
|
original_probs=None, accumulated_hidden=None, true_past=None,
|
||||||
|
582
examples/run_pplm_discrim_train.py
Normal file
582
examples/run_pplm_discrim_train.py
Normal file
@ -0,0 +1,582 @@
|
|||||||
|
#! /usr/bin/env python3
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
# This code is licensed under a non-commercial license.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data as data
|
||||||
|
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||||
|
from torchtext import data as torchtext_data
|
||||||
|
from torchtext import datasets
|
||||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
np.random.seed(0)
|
||||||
|
EPSILON = 1e-10
|
||||||
|
device = 'cpu'
|
||||||
|
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
||||||
|
max_length_seq = 100
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(torch.nn.Module):
|
||||||
|
"""Classification Head for transformer encoders"""
|
||||||
|
|
||||||
|
def __init__(self, class_size, embed_size):
|
||||||
|
super(ClassificationHead, self).__init__()
|
||||||
|
self.class_size = class_size
|
||||||
|
self.embed_size = embed_size
|
||||||
|
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
||||||
|
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
||||||
|
self.mlp = torch.nn.Linear(embed_size, class_size)
|
||||||
|
|
||||||
|
def forward(self, hidden_state):
|
||||||
|
# hidden_state = F.relu(self.mlp1(hidden_state))
|
||||||
|
# hidden_state = self.mlp2(hidden_state)
|
||||||
|
logits = self.mlp(hidden_state)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator(torch.nn.Module):
|
||||||
|
"""Transformer encoder followed by a Classification Head"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_size,
|
||||||
|
pretrained_model="gpt2-medium",
|
||||||
|
cached_mode=False
|
||||||
|
):
|
||||||
|
super(Discriminator, self).__init__()
|
||||||
|
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||||
|
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
||||||
|
self.embed_size = self.encoder.transformer.config.hidden_size
|
||||||
|
self.classifier_head = ClassificationHead(
|
||||||
|
class_size=class_size,
|
||||||
|
embed_size=self.embed_size
|
||||||
|
)
|
||||||
|
self.cached_mode = cached_mode
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.classifier_head
|
||||||
|
|
||||||
|
def train_custom(self):
|
||||||
|
for param in self.encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
pass
|
||||||
|
self.classifier_head.train()
|
||||||
|
|
||||||
|
def avg_representation(self, x):
|
||||||
|
mask = x.ne(0).unsqueeze(2).repeat(
|
||||||
|
1, 1, self.embed_size
|
||||||
|
).float().to(device).detach()
|
||||||
|
hidden, _ = self.encoder.transformer(x)
|
||||||
|
masked_hidden = hidden * mask
|
||||||
|
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
||||||
|
torch.sum(mask, dim=1).detach() + EPSILON
|
||||||
|
)
|
||||||
|
return avg_hidden
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.cached_mode:
|
||||||
|
avg_hidden = x.to(device)
|
||||||
|
else:
|
||||||
|
avg_hidden = self.avg_representation(x)
|
||||||
|
|
||||||
|
logits = self.classifier_head(avg_hidden)
|
||||||
|
probs = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(data.Dataset):
|
||||||
|
def __init__(self, X, y):
|
||||||
|
"""Reads source and target sequences from txt files."""
|
||||||
|
self.X = X
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.X)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""Returns one data pair (source and target)."""
|
||||||
|
data = {}
|
||||||
|
data['X'] = self.X[index]
|
||||||
|
data['y'] = self.y[index]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(data):
|
||||||
|
def pad_sequences(sequences):
|
||||||
|
lengths = [len(seq) for seq in sequences]
|
||||||
|
|
||||||
|
padded_sequences = torch.zeros(
|
||||||
|
len(sequences),
|
||||||
|
max(lengths)
|
||||||
|
).long() # padding index 0
|
||||||
|
|
||||||
|
for i, seq in enumerate(sequences):
|
||||||
|
end = lengths[i]
|
||||||
|
padded_sequences[i, :end] = seq[:end]
|
||||||
|
|
||||||
|
return padded_sequences, lengths
|
||||||
|
|
||||||
|
item_info = {}
|
||||||
|
for key in data[0].keys():
|
||||||
|
item_info[key] = [d[key] for d in data]
|
||||||
|
|
||||||
|
x_batch, _ = pad_sequences(item_info['X'])
|
||||||
|
y_batch = torch.tensor(item_info['y'], dtype=torch.long)
|
||||||
|
|
||||||
|
return x_batch, y_batch
|
||||||
|
|
||||||
|
|
||||||
|
def cached_collate_fn(data):
|
||||||
|
item_info = {}
|
||||||
|
for key in data[0].keys():
|
||||||
|
item_info[key] = [d[key] for d in data]
|
||||||
|
|
||||||
|
x_batch = torch.cat(item_info['X'], 0)
|
||||||
|
y_batch = torch.tensor(item_info['y'], dtype=torch.long)
|
||||||
|
|
||||||
|
return x_batch, y_batch
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(data_loader, discriminator, optimizer,
|
||||||
|
epoch=0, log_interval=10):
|
||||||
|
samples_so_far = 0
|
||||||
|
discriminator.train_custom()
|
||||||
|
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||||
|
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
output_t = discriminator(input_t)
|
||||||
|
loss = F.nll_loss(output_t, target_t)
|
||||||
|
loss.backward(retain_graph=True)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
samples_so_far += len(input_t)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
print(
|
||||||
|
'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||||
|
epoch + 1,
|
||||||
|
samples_so_far, len(data_loader.dataset),
|
||||||
|
100 * samples_so_far / len(data_loader.dataset), loss.item()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_performance(data_loader, discriminator):
|
||||||
|
discriminator.eval()
|
||||||
|
test_loss = 0
|
||||||
|
correct = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for input_t, target_t in data_loader:
|
||||||
|
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||||
|
output_t = discriminator(input_t)
|
||||||
|
# sum up batch loss
|
||||||
|
test_loss += F.nll_loss(output_t, target_t, reduction='sum').item()
|
||||||
|
# get the index of the max log-probability
|
||||||
|
pred_t = output_t.argmax(dim=1, keepdim=True)
|
||||||
|
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
||||||
|
|
||||||
|
test_loss /= len(data_loader.dataset)
|
||||||
|
|
||||||
|
print(
|
||||||
|
'Performance on test set: '
|
||||||
|
'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
|
||||||
|
test_loss, correct, len(data_loader.dataset),
|
||||||
|
100. * correct / len(data_loader.dataset)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(input_sentence, model, classes, cached=False):
|
||||||
|
input_t = model.tokenizer.encode(input_sentence)
|
||||||
|
input_t = torch.tensor([input_t], dtype=torch.long)
|
||||||
|
if cached:
|
||||||
|
input_t = model.avg_representation(input_t)
|
||||||
|
|
||||||
|
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
||||||
|
print('Input sentence:', input_sentence)
|
||||||
|
print('Predictions:', ", ".join(
|
||||||
|
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
|
||||||
|
zip(classes, log_probs)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
||||||
|
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
ys = []
|
||||||
|
for batch_idx, (x, y) in enumerate(data_loader):
|
||||||
|
with torch.no_grad():
|
||||||
|
x = x.to(device)
|
||||||
|
avg_rep = discriminator.avg_representation(x).cpu().detach()
|
||||||
|
avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
|
||||||
|
xs += avg_rep_list
|
||||||
|
ys += y.cpu().numpy().tolist()
|
||||||
|
|
||||||
|
data_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset=Dataset(xs, ys),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
collate_fn=cached_collate_fn)
|
||||||
|
|
||||||
|
return data_loader
|
||||||
|
|
||||||
|
|
||||||
|
def train_discriminator(
|
||||||
|
dataset, dataset_fp=None, pretrained_model='gpt2-medium',
|
||||||
|
epochs=10, batch_size=64, log_interval=10,
|
||||||
|
save_model=False, cached=False, use_cuda=False):
|
||||||
|
if use_cuda:
|
||||||
|
global device
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
print('Preprocessing {} dataset...'.format(dataset))
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
if dataset == 'SST':
|
||||||
|
idx2class = ["positive", "negative", "very positive", "very negative",
|
||||||
|
"neutral"]
|
||||||
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
|
discriminator = Discriminator(
|
||||||
|
class_size=len(idx2class),
|
||||||
|
pretrained_model=pretrained_model,
|
||||||
|
cached_mode=cached
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
text = torchtext_data.Field()
|
||||||
|
label = torchtext_data.Field(sequential=False)
|
||||||
|
train_data, val_data, test_data = datasets.SST.splits(
|
||||||
|
text,
|
||||||
|
label,
|
||||||
|
fine_grained=True,
|
||||||
|
train_subtrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = []
|
||||||
|
y = []
|
||||||
|
for i in range(len(train_data)):
|
||||||
|
seq = TreebankWordDetokenizer().detokenize(
|
||||||
|
vars(train_data[i])["text"]
|
||||||
|
)
|
||||||
|
seq = discriminator.tokenizer.encode(seq)
|
||||||
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
|
x.append(seq)
|
||||||
|
y.append(class2idx[vars(train_data[i])["label"]])
|
||||||
|
train_dataset = Dataset(x, y)
|
||||||
|
|
||||||
|
test_x = []
|
||||||
|
test_y = []
|
||||||
|
for i in range(len(test_data)):
|
||||||
|
seq = TreebankWordDetokenizer().detokenize(
|
||||||
|
vars(test_data[i])["text"]
|
||||||
|
)
|
||||||
|
seq = discriminator.tokenizer.encode(seq)
|
||||||
|
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||||
|
test_x.append(seq)
|
||||||
|
test_y.append(class2idx[vars(test_data[i])["label"]])
|
||||||
|
test_dataset = Dataset(test_x, test_y)
|
||||||
|
|
||||||
|
discriminator_meta = {
|
||||||
|
"class_size": len(idx2class),
|
||||||
|
"embed_size": discriminator.embed_size,
|
||||||
|
"pretrained_model": pretrained_model,
|
||||||
|
"class_vocab": class2idx,
|
||||||
|
"default_class": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif dataset == 'clickbait':
|
||||||
|
idx2class = ["non_clickbait", "clickbait"]
|
||||||
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
|
discriminator = Discriminator(
|
||||||
|
class_size=len(idx2class),
|
||||||
|
pretrained_model=pretrained_model,
|
||||||
|
cached_mode=cached
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||||
|
data = []
|
||||||
|
for i, line in enumerate(f):
|
||||||
|
try:
|
||||||
|
data.append(eval(line))
|
||||||
|
except:
|
||||||
|
print('Error evaluating line {}: {}'.format(
|
||||||
|
i, line
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
x = []
|
||||||
|
y = []
|
||||||
|
y = []
|
||||||
|
for i, d in enumerate(data):
|
||||||
|
try:
|
||||||
|
seq = discriminator.tokenizer.encode(d["text"])
|
||||||
|
|
||||||
|
if len(seq) < max_length_seq:
|
||||||
|
seq = torch.tensor(
|
||||||
|
[50256] + seq, device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Line {} is longer than maximum length {}".format(
|
||||||
|
i, max_length_seq
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
x.append(seq)
|
||||||
|
y.append(d['label'])
|
||||||
|
except:
|
||||||
|
print("Error tokenizing line {}, skipping it".format(i))
|
||||||
|
pass
|
||||||
|
|
||||||
|
full_dataset = Dataset(x, y)
|
||||||
|
train_size = int(0.9 * len(full_dataset))
|
||||||
|
test_size = len(full_dataset) - train_size
|
||||||
|
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||||
|
full_dataset, [train_size, test_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
discriminator_meta = {
|
||||||
|
"class_size": len(idx2class),
|
||||||
|
"embed_size": discriminator.embed_size,
|
||||||
|
"pretrained_model": pretrained_model,
|
||||||
|
"class_vocab": class2idx,
|
||||||
|
"default_class": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif dataset == 'toxic':
|
||||||
|
idx2class = ["non_toxic", "toxic"]
|
||||||
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
|
discriminator = Discriminator(
|
||||||
|
class_size=len(idx2class),
|
||||||
|
pretrained_model=pretrained_model,
|
||||||
|
cached_mode=cached
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with open("datasets/toxic/toxic_train.txt") as f:
|
||||||
|
data = []
|
||||||
|
for i, line in enumerate(f):
|
||||||
|
try:
|
||||||
|
data.append(eval(line))
|
||||||
|
except:
|
||||||
|
print('Error evaluating line {}: {}'.format(
|
||||||
|
i, line
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
x = []
|
||||||
|
y = []
|
||||||
|
for i, d in enumerate(data):
|
||||||
|
try:
|
||||||
|
seq = discriminator.tokenizer.encode(d["text"])
|
||||||
|
|
||||||
|
if len(seq) < max_length_seq:
|
||||||
|
seq = torch.tensor(
|
||||||
|
[50256] + seq, device=device, dtype=torch.long
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Line {} is longer than maximum length {}".format(
|
||||||
|
i, max_length_seq
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
x.append(seq)
|
||||||
|
y.append(int(np.sum(d['label']) > 0))
|
||||||
|
except:
|
||||||
|
print("Error tokenizing line {}, skipping it".format(i))
|
||||||
|
pass
|
||||||
|
|
||||||
|
full_dataset = Dataset(x, y)
|
||||||
|
train_size = int(0.9 * len(full_dataset))
|
||||||
|
test_size = len(full_dataset) - train_size
|
||||||
|
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||||
|
full_dataset, [train_size, test_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
discriminator_meta = {
|
||||||
|
"class_size": len(idx2class),
|
||||||
|
"embed_size": discriminator.embed_size,
|
||||||
|
"pretrained_model": pretrained_model,
|
||||||
|
"class_vocab": class2idx,
|
||||||
|
"default_class": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
else: # if dataset == 'generic':
|
||||||
|
# This assumes the input dataset is a TSV with the following structure:
|
||||||
|
# class \t text
|
||||||
|
|
||||||
|
if dataset_fp is None:
|
||||||
|
raise ValueError('When generic dataset is selected, '
|
||||||
|
'dataset_fp needs to be specified aswell.')
|
||||||
|
|
||||||
|
classes = set()
|
||||||
|
with open(dataset_fp) as f:
|
||||||
|
csv_reader = csv.reader(f, delimiter='\t')
|
||||||
|
for row in csv_reader:
|
||||||
|
classes.add(row[0])
|
||||||
|
|
||||||
|
idx2class = sorted(classes)
|
||||||
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
|
||||||
|
discriminator = Discriminator(
|
||||||
|
class_size=len(idx2class),
|
||||||
|
pretrained_model=pretrained_model,
|
||||||
|
cached_mode=cached
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
x = []
|
||||||
|
y = []
|
||||||
|
with open(dataset_fp) as f:
|
||||||
|
csv_reader = csv.reader(f, delimiter='\t')
|
||||||
|
for i, row in enumerate(csv_reader):
|
||||||
|
label = row[0]
|
||||||
|
text = row[1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
seq = discriminator.tokenizer.encode(text)
|
||||||
|
if (len(seq) < max_length_seq):
|
||||||
|
seq = torch.tensor(
|
||||||
|
[50256] + seq,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Line {} is longer than maximum length {}".format(
|
||||||
|
i, max_length_seq
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
x.append(seq)
|
||||||
|
y.append(class2idx[label])
|
||||||
|
|
||||||
|
except:
|
||||||
|
print("Error tokenizing line {}, skipping it".format(i))
|
||||||
|
pass
|
||||||
|
|
||||||
|
full_dataset = Dataset(x, y)
|
||||||
|
train_size = int(0.9 * len(full_dataset))
|
||||||
|
test_size = len(full_dataset) - train_size
|
||||||
|
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||||
|
full_dataset,
|
||||||
|
[train_size, test_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
discriminator_meta = {
|
||||||
|
"class_size": len(idx2class),
|
||||||
|
"embed_size": discriminator.embed_size,
|
||||||
|
"pretrained_model": pretrained_model,
|
||||||
|
"class_vocab": class2idx,
|
||||||
|
"default_class": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print('Preprocessed {} data points'.format(
|
||||||
|
len(train_dataset) + len(test_dataset))
|
||||||
|
)
|
||||||
|
print("Data preprocessing took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
train_loader = get_cached_data_loader(
|
||||||
|
train_dataset, batch_size, discriminator, shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loader = get_cached_data_loader(
|
||||||
|
test_dataset, batch_size, discriminator
|
||||||
|
)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print("Building representation cache took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
|
else:
|
||||||
|
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
if save_model:
|
||||||
|
with open("{}_classifier_head_meta.json".format(dataset),
|
||||||
|
"w") as meta_file:
|
||||||
|
json.dump(discriminator_meta, meta_file)
|
||||||
|
|
||||||
|
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
start = time.time()
|
||||||
|
print('\nEpoch', epoch + 1)
|
||||||
|
|
||||||
|
train_epoch(
|
||||||
|
discriminator=discriminator,
|
||||||
|
data_loader=train_loader,
|
||||||
|
optimizer=optimizer,
|
||||||
|
epoch=epoch,
|
||||||
|
log_interval=log_interval
|
||||||
|
)
|
||||||
|
evaluate_performance(
|
||||||
|
data_loader=test_loader,
|
||||||
|
discriminator=discriminator
|
||||||
|
)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print("Epoch took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
|
print("\nExample prediction")
|
||||||
|
predict(example_sentence, discriminator, idx2class, cached)
|
||||||
|
|
||||||
|
if save_model:
|
||||||
|
# torch.save(discriminator.state_dict(),
|
||||||
|
# "{}_discriminator_{}.pt".format(
|
||||||
|
# args.dataset, epoch
|
||||||
|
# ))
|
||||||
|
torch.save(discriminator.get_classifier().state_dict(),
|
||||||
|
"{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Train a discriminator on top of GPT-2 representations')
|
||||||
|
parser.add_argument('--dataset', type=str, default='SST',
|
||||||
|
choices=('SST', 'clickbait', 'toxic', 'generic'),
|
||||||
|
help='dataset to train the discriminator on.'
|
||||||
|
'In case of generic, the dataset is expected'
|
||||||
|
'to be a TSBV file with structure: class \\t text')
|
||||||
|
parser.add_argument('--dataset_fp', type=str, default='',
|
||||||
|
help='File path of the dataset to use. '
|
||||||
|
'Needed only in case of generic datadset')
|
||||||
|
parser.add_argument('--pretrained_model', type=str, default='gpt2-medium',
|
||||||
|
help='Pretrained model to use as encoder')
|
||||||
|
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||||
|
help='Number of training epochs')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
|
||||||
|
help='input batch size for training (default: 64)')
|
||||||
|
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
|
||||||
|
help='how many batches to wait before logging training status')
|
||||||
|
parser.add_argument('--save_model', action='store_true',
|
||||||
|
help='whether to save the model')
|
||||||
|
parser.add_argument('--cached', action='store_true',
|
||||||
|
help='whether to cache the input representations')
|
||||||
|
parser.add_argument('--use_cuda', action='store_true',
|
||||||
|
help='use to turn on cuda')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
train_discriminator(**(vars(args)))
|
Loading…
Reference in New Issue
Block a user