huggingface/transformers (https://github.com/huggingface/transformers.git)
commit 611961ade7, parent afc7dcd94d

Added tqdm to preprocessing
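The commit threads tqdm progress bars through the dataset preprocessing loops: `range(...)` becomes `trange(..., ascii=True)`, and plain iterables (data loaders, file handles, CSV readers) are wrapped in `tqdm(..., ascii=True)`. Along the way it also normalizes string literals from single to double quotes. A minimal sketch of the tqdm pattern being applied, with hypothetical stand-in data in place of the real dataset files:

from tqdm import tqdm, trange

# tqdm wraps any iterable and prints a progress bar as items are consumed;
# ascii=True forces a plain '#' bar that renders in any terminal or log file.
lines = ['{"text": "example", "label": 0}'] * 1000  # hypothetical stand-in data
for line in tqdm(lines, ascii=True):
    d = eval(line)  # the script parses each line as a Python literal

# trange(n, ...) is shorthand for tqdm(range(n), ...).
for i in trange(len(lines), ascii=True):
    pass  # per-example preprocessing would go here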
@@ -18,13 +18,14 @@ import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
 from torchtext import data as torchtext_data
 from torchtext import datasets
+from tqdm import tqdm, trange
 
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 torch.manual_seed(0)
 np.random.seed(0)
 EPSILON = 1e-10
-device = 'cpu'
+device = "cpu"
 example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
 max_length_seq = 100
 
@@ -109,8 +110,8 @@ class Dataset(data.Dataset):
     def __getitem__(self, index):
         """Returns one data pair (source and target)."""
         data = {}
-        data['X'] = self.X[index]
-        data['y'] = self.y[index]
+        data["X"] = self.X[index]
+        data["y"] = self.y[index]
         return data
 
 
@@ -133,8 +134,8 @@ def collate_fn(data):
     for key in data[0].keys():
         item_info[key] = [d[key] for d in data]
 
-    x_batch, _ = pad_sequences(item_info['X'])
-    y_batch = torch.tensor(item_info['y'], dtype=torch.long)
+    x_batch, _ = pad_sequences(item_info["X"])
+    y_batch = torch.tensor(item_info["y"], dtype=torch.long)
 
     return x_batch, y_batch
 
@@ -144,8 +145,8 @@ def cached_collate_fn(data):
     for key in data[0].keys():
         item_info[key] = [d[key] for d in data]
 
-    x_batch = torch.cat(item_info['X'], 0)
-    y_batch = torch.tensor(item_info['y'], dtype=torch.long)
+    x_batch = torch.cat(item_info["X"], 0)
+    y_batch = torch.tensor(item_info["y"], dtype=torch.long)
 
     return x_batch, y_batch
 
@@ -168,7 +169,7 @@ def train_epoch(data_loader, discriminator, optimizer,
 
         if batch_idx % log_interval == 0:
             print(
-                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                     epoch + 1,
                     samples_so_far, len(data_loader.dataset),
                     100 * samples_so_far / len(data_loader.dataset), loss.item()
@@ -185,7 +186,7 @@ def evaluate_performance(data_loader, discriminator):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction='sum').item()
+            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
@@ -193,8 +194,8 @@ def evaluate_performance(data_loader, discriminator):
     test_loss /= len(data_loader.dataset)
 
     print(
-        'Performance on test set: '
-        'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
+        "Performance on test set: "
+        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
             test_loss, correct, len(data_loader.dataset),
             100. * correct / len(data_loader.dataset)
         )
@@ -208,8 +209,8 @@ def predict(input_sentence, model, classes, cached=False):
         input_t = model.avg_representation(input_t)
 
     log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
-    print('Input sentence:', input_sentence)
-    print('Predictions:', ", ".join(
+    print("Input sentence:", input_sentence)
+    print("Predictions:", ", ".join(
         "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
         zip(classes, log_probs)
     ))
@@ -222,7 +223,7 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
 
     xs = []
    ys = []
-    for batch_idx, (x, y) in enumerate(data_loader):
+    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
         with torch.no_grad():
             x = x.to(device)
             avg_rep = discriminator.avg_representation(x).cpu().detach()
@@ -240,16 +241,16 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
 
 
 def train_discriminator(
-        dataset, dataset_fp=None, pretrained_model='gpt2-medium',
+        dataset, dataset_fp=None, pretrained_model="gpt2-medium",
         epochs=10, batch_size=64, log_interval=10,
         save_model=False, cached=False, no_cuda=False):
     global device
     device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
 
-    print('Preprocessing {} dataset...'.format(dataset))
+    print("Preprocessing {} dataset...".format(dataset))
     start = time.time()
 
-    if dataset == 'SST':
+    if dataset == "SST":
         idx2class = ["positive", "negative", "very positive", "very negative",
                      "neutral"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
@@ -271,7 +272,7 @@ def train_discriminator(
 
         x = []
         y = []
-        for i in range(len(train_data)):
+        for i in trange(len(train_data), ascii=True):
             seq = TreebankWordDetokenizer().detokenize(
                 vars(train_data[i])["text"]
             )
@@ -283,7 +284,7 @@ def train_discriminator(
 
         test_x = []
         test_y = []
-        for i in range(len(test_data)):
+        for i in trange(len(test_data), ascii=True):
             seq = TreebankWordDetokenizer().detokenize(
                 vars(test_data[i])["text"]
             )
@@ -301,7 +302,7 @@ def train_discriminator(
             "default_class": 2,
         }
 
-    elif dataset == 'clickbait':
+    elif dataset == "clickbait":
         idx2class = ["non_clickbait", "clickbait"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
 
@@ -317,31 +318,33 @@ def train_discriminator(
                 try:
                     data.append(eval(line))
                 except:
-                    print('Error evaluating line {}: {}'.format(
+                    print("Error evaluating line {}: {}".format(
                         i, line
                     ))
                     continue
         x = []
         y = []
-        y = []
-        for i, d in enumerate(data):
-            try:
-                seq = discriminator.tokenizer.encode(d["text"])
+        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
+            for i, line in enumerate(tqdm(f, ascii=True)):
+                try:
+                    d = eval(line)
+                    seq = discriminator.tokenizer.encode(d["text"])
 
-                if len(seq) < max_length_seq:
-                    seq = torch.tensor(
-                        [50256] + seq, device=device, dtype=torch.long
-                    )
-                else:
-                    print("Line {} is longer than maximum length {}".format(
-                        i, max_length_seq
-                    ))
-                    continue
-                x.append(seq)
-                y.append(d['label'])
-            except:
-                print("Error tokenizing line {}, skipping it".format(i))
-                pass
+                    if len(seq) < max_length_seq:
+                        seq = torch.tensor(
+                            [50256] + seq, device=device, dtype=torch.long
+                        )
+                    else:
+                        print("Line {} is longer than maximum length {}".format(
+                            i, max_length_seq
+                        ))
+                        continue
+                    x.append(seq)
+                    y.append(d["label"])
+                except:
+                    print("Error evaluating / tokenizing"
+                          " line {}, skipping it".format(i))
+                    pass
 
         full_dataset = Dataset(x, y)
         train_size = int(0.9 * len(full_dataset))
@@ -358,7 +361,7 @@ def train_discriminator(
             "default_class": 1,
         }
 
-    elif dataset == 'toxic':
+    elif dataset == "toxic":
         idx2class = ["non_toxic", "toxic"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
 
@@ -368,37 +371,29 @@ def train_discriminator(
             cached_mode=cached
         ).to(device)
 
-        with open("datasets/toxic/toxic_train.txt") as f:
-            data = []
-            for i, line in enumerate(f):
-                try:
-                    data.append(eval(line))
-                except:
-                    print('Error evaluating line {}: {}'.format(
-                        i, line
-                    ))
-                    continue
-
         x = []
         y = []
-        for i, d in enumerate(data):
-            try:
-                seq = discriminator.tokenizer.encode(d["text"])
+        with open("datasets/toxic/toxic_train.txt") as f:
+            for i, line in enumerate(tqdm(f, ascii=True)):
+                try:
+                    d = eval(line)
+                    seq = discriminator.tokenizer.encode(d["text"])
 
-                if len(seq) < max_length_seq:
-                    seq = torch.tensor(
-                        [50256] + seq, device=device, dtype=torch.long
-                    )
-                else:
-                    print("Line {} is longer than maximum length {}".format(
-                        i, max_length_seq
-                    ))
-                    continue
-                x.append(seq)
-                y.append(int(np.sum(d['label']) > 0))
-            except:
-                print("Error tokenizing line {}, skipping it".format(i))
-                pass
+                    if len(seq) < max_length_seq:
+                        seq = torch.tensor(
+                            [50256] + seq, device=device, dtype=torch.long
+                        )
+                    else:
+                        print("Line {} is longer than maximum length {}".format(
+                            i, max_length_seq
+                        ))
+                        continue
+                    x.append(seq)
+                    y.append(int(np.sum(d["label"]) > 0))
+                except:
+                    print("Error evaluating / tokenizing"
+                          " line {}, skipping it".format(i))
+                    pass
 
         full_dataset = Dataset(x, y)
         train_size = int(0.9 * len(full_dataset))
@@ -415,18 +410,18 @@ def train_discriminator(
             "default_class": 0,
         }
 
-    else: # if dataset == 'generic':
+    else: # if dataset == "generic":
         # This assumes the input dataset is a TSV with the following structure:
         # class \t text
 
         if dataset_fp is None:
-            raise ValueError('When generic dataset is selected, '
-                             'dataset_fp needs to be specified aswell.')
+            raise ValueError("When generic dataset is selected, "
+                             "dataset_fp needs to be specified aswell.")
 
         classes = set()
         with open(dataset_fp) as f:
-            csv_reader = csv.reader(f, delimiter='\t')
-            for row in csv_reader:
+            csv_reader = csv.reader(f, delimiter="\t")
+            for row in tqdm(csv_reader, ascii=True):
                 if row:
                     classes.add(row[0])
 
@@ -442,8 +437,8 @@ def train_discriminator(
         x = []
         y = []
         with open(dataset_fp) as f:
-            csv_reader = csv.reader(f, delimiter='\t')
-            for i, row in enumerate(csv_reader):
+            csv_reader = csv.reader(f, delimiter="\t")
+            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                 if row:
                     label = row[0]
                     text = row[1]
@@ -458,9 +453,10 @@ def train_discriminator(
                         )
 
                     else:
-                        print("Line {} is longer than maximum length {}".format(
-                            i, max_length_seq
-                        ))
+                        print(
+                            "Line {} is longer than maximum length {}".format(
+                                i, max_length_seq
+                            ))
                         continue
 
                     x.append(seq)
@@ -487,12 +483,14 @@ def train_discriminator(
     }
 
     end = time.time()
-    print('Preprocessed {} data points'.format(
+    print("Preprocessed {} data points".format(
         len(train_dataset) + len(test_dataset))
     )
+    print("Data preprocessing took: {:.3f}s".format(end - start))
 
     if cached:
         print("Building representation cache...")
+
+        start = time.time()
 
         train_loader = get_cached_data_loader(
@@ -524,7 +522,7 @@ def train_discriminator(
 
     for epoch in range(epochs):
         start = time.time()
-        print('\nEpoch', epoch + 1)
+        print("\nEpoch", epoch + 1)
 
         train_epoch(
             discriminator=discriminator,
@@ -553,31 +551,31 @@ def train_discriminator(
             "{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Train a discriminator on top of GPT-2 representations')
-    parser.add_argument('--dataset', type=str, default='SST',
-                        choices=('SST', 'clickbait', 'toxic', 'generic'),
-                        help='dataset to train the discriminator on.'
-                             'In case of generic, the dataset is expected'
-                             'to be a TSBV file with structure: class \\t text')
-    parser.add_argument('--dataset_fp', type=str, default='',
-                        help='File path of the dataset to use. '
-                             'Needed only in case of generic datadset')
-    parser.add_argument('--pretrained_model', type=str, default='gpt2-medium',
-                        help='Pretrained model to use as encoder')
-    parser.add_argument('--epochs', type=int, default=10, metavar='N',
-                        help='Number of training epochs')
-    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save_model', action='store_true',
-                        help='whether to save the model')
-    parser.add_argument('--cached', action='store_true',
-                        help='whether to cache the input representations')
-    parser.add_argument('--no_cuda', action='store_true',
-                        help='use to turn off cuda')
+        description="Train a discriminator on top of GPT-2 representations")
+    parser.add_argument("--dataset", type=str, default="SST",
+                        choices=("SST", "clickbait", "toxic", "generic"),
+                        help="dataset to train the discriminator on."
+                             "In case of generic, the dataset is expected"
+                             "to be a TSBV file with structure: class \\t text")
+    parser.add_argument("--dataset_fp", type=str, default="",
+                        help="File path of the dataset to use. "
+                             "Needed only in case of generic datadset")
+    parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
+                        help="Pretrained model to use as encoder")
+    parser.add_argument("--epochs", type=int, default=10, metavar="N",
+                        help="Number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=64, metavar="N",
+                        help="input batch size for training (default: 64)")
+    parser.add_argument("--log_interval", type=int, default=10, metavar="N",
+                        help="how many batches to wait before logging training status")
+    parser.add_argument("--save_model", action="store_true",
+                        help="whether to save the model")
+    parser.add_argument("--cached", action="store_true",
+                        help="whether to cache the input representations")
+    parser.add_argument("--no_cuda", action="store_true",
+                        help="use to turn off cuda")
     args = parser.parse_args()
 
     train_discriminator(**(vars(args)))
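One property of the new clickbait and toxic loops worth noting: iterating a bare file handle gives tqdm no length, so the bar degrades to a running item count and rate instead of a percentage. A small sketch, assuming a hypothetical local file, of how a total could be supplied if the percentage display were wanted:

from tqdm import tqdm

path = "datasets/toxic/toxic_train.txt"  # hypothetical local path

# A file object has no len(), so tqdm prints e.g. "123it [00:01, 98.7it/s]".
# Pre-counting the lines and passing total= restores the percent/ETA bar.
with open(path) as f:
    total = sum(1 for _ in f)

with open(path) as f:
    for i, line in enumerate(tqdm(f, total=total, ascii=True)):
        pass  # per-line preprocessing, as in train_discriminator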