Add Information Gain Filtration algorithm (#16953)

* Add information gain filtration algorithm

* Complying with black requirements

* Added author

* Fixed import order

* flake8 corrections

Co-authored-by: Javier Turek <javier.turek@intel.com>
mraunak 2022-05-18 10:39:02 -04:00 committed by GitHub
parent 91ede485a7
commit 5fdb54ece7
6 changed files with 963 additions and 0 deletions


@@ -0,0 +1,100 @@
# Information Gain Filtration (IGF)

Authors @Tuko @mraunak

This folder contains code showing how to implement IGF for fine-tuning GPT-2.
## What is IGF?
Here we present a general fine-tuning method, which we call Information Gain Filtration (IGF), for improving the overall training efficiency and final
performance of language model fine-tuning (see the paper below). The method trains
a secondary model (e.g., a simple convolutional network) to predict the amount of information
gained over a given pre-trained model. The secondary model is lightweight and trained to
predict the Information Gain measure. Information Gain is defined as the change in a loss
function for a model before and after an SGD update with a sample (Equation X in the paper).
A small subset of the training set, named the “objective” set, is used to measure information
gain on the pre-trained model and, consequently, to train the secondary model. After
training, the secondary model is used to filter samples during the fine-tuning process: a high
information gain value suggests a sample is informative, whereas a low value
suggests a non-informative sample that should be filtered out. A thresholding
strategy is therefore used to select informative samples. With such a strategy, samples are filtered,
and once enough samples are selected to form a mini-batch, a standard fine-tuning/optimization
step is applied. The filtration process is repeated until fine-tuning is over.

Paper: [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)
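For intuition, the following is a small, self-contained sketch of the idea. It is illustrative only: a tiny linear model and an MSE loss stand in for GPT-2 and perplexity, and the information gain is computed exactly for each sample rather than predicted by a trained secondary learner as in the code in this folder.

```python
# Minimal, self-contained sketch of the IGF idea (illustrative only; the actual
# implementation lives in igf/igf.py and run_clm_igf.py). A tiny linear model and
# MSE loss stand in for GPT-2 and perplexity.
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(16, 1)            # stand-in for the pretrained language model
objective_x = torch.randn(64, 16)   # stand-in for the held-out "objective" set
objective_y = torch.randn(64, 1)
loss_fn = nn.MSELoss()


def objective_loss(m):
    # Loss on the objective set plays the role of objective-set perplexity.
    with torch.no_grad():
        return loss_fn(m(objective_x), objective_y).item()


def information_gain(sample_x, sample_y, lr=1e-2):
    # IG(X): change in objective loss before vs. after one SGD update on sample X.
    trial = copy.deepcopy(model)
    before = objective_loss(trial)
    optimizer = torch.optim.SGD(trial.parameters(), lr=lr)
    optimizer.zero_grad()
    loss_fn(trial(sample_x), sample_y).backward()
    optimizer.step()
    return before - objective_loss(trial)  # positive => the sample was informative


# Thresholded filtering: keep only samples whose information gain clears the bar.
# (The real code thresholds the *predicted* IG from the secondary learner instead.)
threshold = 0.0
minibatch = []
for _ in range(32):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    if information_gain(x, y) > threshold:
        minibatch.append((x, y))
print(f"kept {len(minibatch)} of 32 candidate samples")
```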
## Results

Several experiments were conducted to show the robustness of the IGF method versus the
standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
Books dataset, compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
implemented using the Transformers library and PyTorch. While the method may seem more
expensive, we saw enough evidence that it may lead to a performance benefit in the final models.
![IGF performance](result_igf.png)
Figure 1: Comparing IGF to Standard Fine-tuning:
IGF with constant (p < 10^-3, t-test) and shifting (p < 10^-6, t-test) thresholding significantly outperforms standard fine-tuning. The left-hand figure shows
test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote ± one standard error). The right-hand figure shows the perplexity of each
method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam.
## How to use this project?
To fine-tune a transformer model with IGF on a language modeling task, run `run_clm_igf.py` with the arguments described below:
- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
- `data_file`: A jbl file containing tokenized data which can be split as objective dataset,
train_dataset and test_dataset
- `igf_data_file`: A jbl file containing the context and information gain pairs to train secondary learner.
- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded.
- `size_objective_set`: Number of articles that are long enough to be used as our objective set
- `min_len`: The minimum length of the article to be used as objective set
- `trim`: Truncate the example if it exceeds context length
- `eval_freq`: Secondary model evaluation is triggered every `eval_freq` steps
- `max_steps`: Number of optimization steps, used to calculate the number of training epochs and the learning-rate schedule
- `number`: The number of examples split to be used as objective_set/test_data
- `secondary_learner_batch_size`: The batch size of training data for secondary learner
- `secondary_learner_max_epochs`: The number of epochs to train secondary learner
- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
- `eval_interval`: Decay the selectivity of our secondary learner filter from
1 standard deviation above average to 1 below average after `eval_interval` (10) batches
```bash
python run_clm_igf.py \
    --model_name_or_path "gpt2" \
    --data_file "data/tokenized_stories_train_wikitext103" \
    --igf_data_file "data/IGF_values" \
    --context_len 32 \
    --size_objective_set 100 \
    --min_len 1026 \
    --trim True \
    --eval_freq 100 \
    --max_steps 1000 \
    --secondary_learner_batch_size 128 \
    --secondary_learner_max_epochs 15 \
    --number 100 \
    --recopy_model \
    --eval_interval 10
```
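After training, `run_clm_igf.py` saves the secondary learner to `igf_model.pt` by default. The snippet below is a minimal sketch of how the saved learner could be reloaded with `SecondaryLearner.from_pretrained` from `igf/igf.py` and queried directly; the random `context` tensor is only a placeholder for real tokenized contexts.

```python
import torch
from transformers import GPT2LMHeadModel

from igf.igf import SecondaryLearner

# Rebuild the secondary learner around the pretrained GPT-2 embeddings
model = GPT2LMHeadModel.from_pretrained("gpt2")
secondary_learner = SecondaryLearner.from_pretrained("igf_model.pt", model)
secondary_learner.eval()

# Score a batch of token-id contexts (shape: batch_size x context_len)
context = torch.randint(0, model.config.vocab_size, (4, 32))
with torch.no_grad():
    predicted_ig = secondary_learner(context)
print(predicted_ig)  # higher values suggest more informative contexts
```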
## Citation
If you find the resource useful, please cite the following paper:
```
@inproceedings{antonello-etal-2021-selecting,
title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
month = aug,
year = "2021",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.acl-long.87",
doi = "10.18653/v1/2021.acl-long.87",
pages = "1072--1085",
}
```


@@ -0,0 +1,419 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
import copy
import logging
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import joblib
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
logger = logging.getLogger(__name__)
def set_seed(seed):
"""
For reproducible training
Args:
seed: A seed for reproducible training
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def compute_perplexity(model, test_data, context_len):
"""
Computes perplexity of the transformer model on data in test_data
Args:
model: Pre-trained GPT2 model
test_data: Data on which perplexity calculation is required
context_len: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded
Returns:
Perplexity on input test data
"""
model.eval()
device = next(model.parameters()).device
eval_batch_size = 1
context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
eval_loss = torch.zeros(1, device=device)
nb_eval_examples = 0
for batch in eval_dataloader:
batch.to(device)
# pad
context.zero_()
for i in range(eval_batch_size):
context[i, :] = batch[i]
outputs = model(context, labels=context)
eval_loss += outputs[0].sum().item()
nb_eval_examples += batch.size(0)
eval_loss = eval_loss / nb_eval_examples
perplexity = torch.exp(eval_loss)
model.train()
return perplexity
def load_gpt2(model_name="gpt2"):
"""
load original gpt2 and save off for quicker loading
Args:
model_name: GPT-2
Returns:
GPT-2 model
"""
model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
torch.save(model.state_dict(), model_name + "local.pt")
return model
def recopy_gpt2(orig_model, device, max_steps):
"""
Reset the model to the original pretrained GPT-2 weights after each iteration
Args:
orig_model: Original pretrained GPT-2 model imported from Transformers library
device: CPU/GPU
max_steps: number of training steps
Returns:
Original PreTrained GPT-2 model,
lm_optimizer: Adam optimizer with Decoupled weight decay
lm_scheduler: linear scheduler with the appropriate schedule
"""
model = copy.deepcopy(orig_model)
model.to(device)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
torch.cuda.empty_cache()
return model, lm_optimizer, lm_scheduler
def intermittent_save(contexts, real_perps, past_perps, filename):
"""
save the perplexity differences to filename
Args:
contexts: Example on which the perplexity is calculated
real_perps: Perplexity after back-propagating on the selected context
past_perps: Perplexity of model before training on the context
filename: File to store perplexity differences
Returns:
file with perplexity differences
"""
# save the perplexity differences to filename
avg = np.array(real_perps).mean()
std = np.array(real_perps).std()
perp_diff = (real_perps - avg) / std
data_final = list(zip(contexts, perp_diff, past_perps))
joblib.dump(data_final, filename)
def collect_objective_set(
model,
orig_perp,
context_len,
train_data,
objective_set,
max_steps,
device,
filename="dev.jbl",
recopy_model=recopy_gpt2,
):
"""
Collect individual IGF values from pre-trained transformer model
max_steps samples of training data to train secondary model
Args:
model: Pre-trained GPT2 model
orig_perp: Perplexity of original pretrained GPT-2 model
context_len: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded
train_data: Data to train model
objective_set: Contexts used to create (X, IG(X)) pairs, which are the training data for the secondary learner
max_steps: To calculate training epochs of model
device: GPU/CPU
filename: To store intermediate perplexity differences
recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
Returns:
file stored intermediate perplexity differences in intermediate stages
"""
# initialize variables to record relevant information
contexts = []
real_perps = []
past_perps = []
# Initialize the transformer model
orig_model = copy.deepcopy(model)
orig_model.to(device="cpu")
torch.cuda.empty_cache()
# Compute perplexity of initial transformer model for comparison
model.train()
model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
for step in tqdm(range(max_steps)):
context = torch.zeros((1, context_len), dtype=torch.long, device=device)
story = random.choice(train_data)
start = random.randint(0, len(story[0]) - context_len - 1)
context[0, :] = story[0][start : start + context_len]
lm_optimizer.zero_grad()
outputs = model(context, labels=context)
lm_loss = outputs[0]
past_perp = compute_perplexity(model, context, context_len)
model.train()
lm_loss.backward()
# Do LM backprop
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
lm_optimizer.step()
lm_scheduler.step() # Update learning rate schedule
# Compute perplexity after back-propagating on the selected context
real_perp = compute_perplexity(model, objective_set, context_len)
# Periodically save the stored (X, IG(X)) pairs
if step % 1000 == 0 and step > 1:
intermittent_save(contexts, real_perps, past_perps, filename)
# Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
past_perps.append(past_perp.item())
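# Information gain for this context: reduction in objective-set perplexity relative to the original model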
real_perps.append(orig_perp - real_perp.item())
contexts.append(np.array(context.cpu()))
intermittent_save(contexts, real_perps, past_perps, filename)
def generate_datasets(
context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
):
"""
Generate objective set and training set
Args:
context_len: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded
file: Tokenized data split into training set and objective set
number: size of objective dataset
min_len: minimum length of a context in objective set
trim: If True truncate the context if it exceeds context length
Returns:
Generated objective set and training data
"""
# Generate objective set and training set
# Designate the first number (100) articles that are long enough to be used
# as our objective set, rest (that are long enough) are training data for
# secondary learner
data = joblib.load(file)
print("data loaded")
objective_set = []
if trim:
for i, example in enumerate(data):
if len(example[0]) > min_len:
start = random.randint(0, len(example[0]) - context_len - 1)
objective_set.append(example[0, start : start + context_len])
if len(objective_set) >= number:
break
train_data = []
for j in range(i + 1, len(data)):
if len(data[j][0]) > min_len:
train_data.append(data[j])
else:
objective_set = data[0:number]
train_data = data[number:]
joblib.dump(objective_set, "objective_set.jbl")
print("objective set saved")
return train_data, objective_set
def train_secondary_learner(
secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
):
"""
Train the secondary learner (igf_model)
Args:
secondary_learner: secondary learner
train_dataset: data to train secondary learner
max_epochs: number of epochs to train secondary learner
batch_size: batch size of training data of secondary learner
eval_freq: secondary model evaluation can be triggered at eval_freq
igf_model_path: path to store trained secondary learner
Returns:
Trained secondary learner
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# We will use the first 512 pairs from our dataset as a test set for
# our secondary learner and the rest to train
test_dataset = train_dataset[:512]
train_dataset = train_dataset[512:]
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
# secondary learner model set up
loss = nn.MSELoss()
test_loss = nn.MSELoss(reduction="sum")
secondary_learner.to(device)
q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
secondary_learner.train()
# TODO in original code this is written as number of actual batches seen
# not number of items seen but other places it is number of items instead.
# improve consistency! changed this to epochs for clarity
best_test_loss = float("inf")
# Iterate through batches until we've used max_steps batches
for epoch in range(int(max_epochs)):
tr_q_loss = 0.0
secondary_learner.train()
for step, batch in enumerate(train_dataloader):
context = batch[0].to(device)
real_q = batch[1].to(device)
predicted_q = secondary_learner(context)
q_optimizer.zero_grad()
q_loss = loss(predicted_q, real_q.float())
q_loss.backward()
q_optimizer.step()
tr_q_loss += q_loss.item()
# model trains fairly quickly so we won't wait for a full epoch
# eval is triggered at eval_freq and end of epochs
if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
tr_loss = tr_q_loss / (step + 1)
secondary_learner.eval()
q_loss2 = 0.0
sum_q2 = 0.0
predicted = []
actual = []
# Compute performance of the secondary learner after this batch
for step2, batch2 in enumerate(test_dataloader):
features2 = batch2[0].to(device)
real_q2 = batch2[1].to(device)
predicted_q2 = secondary_learner(features2)
q_loss2 += test_loss(predicted_q2, real_q2).item()
sum_q2 += torch.sum(predicted_q2).item()
for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
predicted.append(i.item())
for ei, i in enumerate(real_q2.cpu().detach().numpy()):
actual.append(i.item())
q_loss2 /= len(test_dataset)
print(
"Epoch: ",
epoch,
"step: ",
step,
"Avg. q:",
sum_q2 / len(test_dataset),
"Train Loss: ",
tr_loss,
"Test Loss: ",
q_loss2,
)
if q_loss2 < best_test_loss:
joblib.dump((predicted, actual), "pred_vs_actual.jbl")
torch.save(secondary_learner.state_dict(), igf_model_path)
best_test_loss = q_loss2
secondary_learner.train()
return secondary_learner
class SecondaryLearner(nn.Module):
"""
Our secondary learner
"""
def __init__(self, model):
"""
We use a simple convolutional network as our secondary learner
Args:
model: Pre-trained GPT2 model
"""
# embeddings are from the pretrained model
super(SecondaryLearner, self).__init__()
self.embeddings = model.transformer.wte
self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))
def forward(self, context):
"""
Forward pass through the secondary learner
Args:
context: Context input to the secondary learner
Returns:
tensor after squeeze operation
"""
pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
qs = self.fc(pooled)
return qs.squeeze(1)
@classmethod
def from_pretrained(cls, state_path, model):
"""
Load the secondary learner
Args:
state_path: Path to save secondary learner
model: Pretrained GPT-2
Returns:
secondary learner
"""
secondary_learner = cls(model) # this calls __init__
state_dict = torch.load(state_path)
secondary_learner.load_state_dict(state_dict)
secondary_learner.embeddings = model.transformer.wte
secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
return secondary_learner


@@ -0,0 +1,6 @@
matplotlib
numpy>=1.17.2
joblib>=0.13.2
scipy
torch>=1.10.1
transformers>=3.5

Binary file not shown.



@@ -0,0 +1,438 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration (IGF) on the WikiText data set, compared against the
standard fine-tuning method.
Steps followed in the code:
1) Generate an objective dataset of pairs (X, IG(X)), where IG(X) is the informativeness of context 'X'.
Our IG (information gain) model learns to predict the informativeness of a particular
context. Informativeness is the change in metric between the model's accuracy on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.
2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).
3) The learner created in (2) is used to inform the fine-tuning process and filter out low-informative samples.
Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering.
"""
# Prerequisite libraries:
import argparse
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
import joblib
from igf.igf import (
SecondaryLearner,
collect_objective_set,
compute_perplexity,
generate_datasets,
load_gpt2,
recopy_gpt2,
set_seed,
train_secondary_learner,
)
from transformers import GPT2LMHeadModel
def generate_n_pairs(
context_len=32,
max_steps=10,
size_objective_set=100,
min_len=1026,
trim=True,
data_file="data/tokenized_stories_train_wikitext103.jbl",
igf_data_file="igf_context_pairs.jbl",
):
"""
Collecting *n* pairs for training the secondary learner
Args:
context_len: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded
max_steps: To calculate training epochs of secondary learner
size_objective_set: size of the objective data set used to create (X, IG(X)) pairs, which are the training data for the secondary learner
min_len: The minimum length of the article to be used as objective set
trim: If True truncate the context if it exceeds context length
data_file: Tokenized data set split for training and evaluation of model
igf_data_file: file to store the (X, IG(X)) paired data set to train the secondary learner
Returns:
Data stored in igf_data_file
"""
# generates the same data every time
set_seed(3)
# generate train_data and objective_set
train_data, objective_set = generate_datasets(
context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
)
# keeps model same across runs
set_seed(4)
# model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
# can we train on GPU?
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrained model
model = load_gpt2("gpt2").to(device)
print("computing perplexity on objective set")
orig_perp = compute_perplexity(model, objective_set, context_len).item()
print("perplexity on objective set:", orig_perp)
# collect igf pairs and save to file demo.jbl
collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)
# clean up, delete model and data we don't need anymore
del model, train_data, objective_set
torch.cuda.empty_cache()
def training_secondary_learner(
secondary_learner_train_data,
secondary_learner_max_epochs=15,
secondary_learner_batch_size=128,
eval_freq=100,
igf_model_path="igf_model.pt",
):
"""
Train the secondary learner
Args:
secondary_learner_train_data: Data set of (X, IG(X)) pairs to train the secondary learner, where IG(X) is the measure of informativeness of context X
secondary_learner_max_epochs: Number of epochs to train secondary learner
secondary_learner_batch_size: Batch size to train secondary learner
eval_freq (object): secondary model evaluation can be triggered at eval_freq
igf_model_path: path to store trained secondary learner
Returns:
Trained secondary learner
"""
set_seed(42)
# Load pre-trained model
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Initialize secondary learner to use embedding weights of model
secondary_learner = SecondaryLearner(model)
# Train secondary learner
secondary_learner = train_secondary_learner(
secondary_learner,
secondary_learner_train_data,
max_epochs=secondary_learner_max_epochs,
batch_size=secondary_learner_batch_size,
eval_freq=eval_freq,
igf_model_path=igf_model_path,
)
del model, secondary_learner_train_data
torch.cuda.empty_cache()
return secondary_learner
def finetune(
model,
train_dataset,
test_dataset,
context_len=32,
max_steps=1000,
batch_size=16,
threshold=1.0,
recopy_model=recopy_gpt2,
secondary_learner=None,
eval_interval=10,
finetuned_model_name="gpt2_finetuned.pt",
):
"""
fine-tune with IGF if secondary_learner is not None, else standard fine-tuning
Args:
model: pre-trained GPT-2 model
train_dataset: Data set to train GPT-2 model
test_dataset: Evaluate GPT-2 model
context_len: The maximum total input sequence length after tokenization. Sequences longer
than this will be truncated, sequences shorter will be padded
max_steps: To calculate training epochs
batch_size: Batch size to train GPT-2 model
threshold: The threshold value used by the secondary learner to filter the train_data and allow only
informative data as input to the model
recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
secondary_learner: Selection of IGF as fine-tuning method if not None
eval_interval: number of batches after which to decay the selectivity of our secondary learner filter from
1 standard deviation above average to 1 below average
finetuned_model_name: name of the final fine-tuned GPT-2 model
Returns:
Fine-tuned GPT-2 model
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler)
num_train_epochs = max_steps // (len(train_dataset)) + 1
global_step = 0
context = torch.zeros((1, context_len), dtype=torch.long, device=device)
model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)
model.train()
if secondary_learner is not None:
secondary_learner.to(device)
secondary_learner.eval()
contexts = []
examples = 0
observed_qs = []
test_perps = []
# Compute the performance of the transformer model at the beginning
real_perp = compute_perplexity(model, test_dataset, context_len)
test_perps.append(real_perp)
print("Test perplexity, step", global_step, ":", real_perp)
for epoch in range(int(num_train_epochs)):
for step, example in enumerate(train_dataloader):
torch.cuda.empty_cache()
start = random.randint(0, example.size(2) - context_len - 1)
context[0, :] = example[0, 0, start : start + context_len]
lm_optimizer.zero_grad()
outputs = model(context, labels=context)
do_backprop = True
if secondary_learner is not None:
predicted_q = secondary_learner.forward(
torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
)[0].item()
observed_qs.append(float(predicted_q))
# Here we implement the simple non-constant threshold for the predicted IG(X) value
# We will decay the selectivity of our secondary learner filter from
# 1 standard deviation above average to 1 below average after 10 batches.
if global_step == 10:
threshold = -1
if predicted_q < threshold:
do_backprop = False
# If we passed the filter, add the context to the batch!
if do_backprop:
contexts.append(np.array(context.cpu()))
lm_loss = outputs[0]
lm_loss.backward()
examples += 1
del outputs
# Once the batch is filled with enough contexts, backprop on the batch.
if examples == batch_size:
torch.cuda.empty_cache()
examples = 0
# Do LM backprop
torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
lm_optimizer.step()
lm_scheduler.step() # Update learning rate schedule
global_step += 1
# Compute the performance of the transformer model at this batch
if global_step % eval_interval == 0:
real_perp = compute_perplexity(model, test_dataset, context_len)
test_perps.append(real_perp)
print("Test perplexity, step", global_step, ":", real_perp)
# Break out of the loop after 60 batches
if max_steps > 0 and global_step > 60:
break
if max_steps > 0 and global_step > 60:
break
# save finetuned transformer model
torch.save(model.state_dict(), finetuned_model_name)
torch.cuda.empty_cache()
# Do some cleaning up so we can reinitialize for the next run of this function
del lm_optimizer
del lm_scheduler
return model
def main():
parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain data files for WikiText.",
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--data_file",
type=str,
default=None,
help="A jbl file containing tokenized data which can be split as objective dataset, "
"train_dataset and test_dataset.",
)
parser.add_argument(
"--igf_data_file",
type=str,
default=None,
help="A jbl file containing the context and information gain pairs to train secondary learner.",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the final fine-tuned model is stored.",
)
parser.add_argument(
"--tokenizer_name",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--context_len",
default=32,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--size_objective_set",
default=100,
type=int,
help="number of articles that are long enough to be used as our objective set",
)
parser.add_argument(
"--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
)
parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")
parser.add_argument(
"--secondary_learner_batch_size",
default=128,
type=int,
help="batch size of training data for secondary learner",
)
parser.add_argument(
"--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
)
parser.add_argument(
"--eval_interval",
default=10,
type=int,
help="decay the selectivity of our secondary learner filter from"
"1 standard deviation above average to 1 below average after 10 batches",
)
parser.add_argument(
"--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
)
parser.add_argument(
"--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
)
parser.add_argument(
"--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
)
parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")
parser.add_argument(
"--threshold",
default=1.0,
type=float,
help="The threshold value used by secondary learner to filter the train_data and allow only"
" informative data as input to the model",
)
parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")
parser.add_argument(
"--recopy_model",
default=recopy_gpt2,
type=str,
help="Reset the model to the original pretrained GPT-2 weights after each iteration",
)
# function calls
# Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
generate_n_pairs(
context_len=32,
max_steps=10,
size_objective_set=100,
min_len=1026,
trim=True,
data_file="data/tokenized_stories_train_wikitext103.jbl",
igf_data_file="igf_context_pairs.jbl",
)
# Load train data for secondary learner
secondary_learner_train_data = joblib.load("data/IGF_values.jbl")
# Train secondary learner
secondary_learner = training_secondary_learner(
secondary_learner_train_data,
secondary_learner_max_epochs=15,
secondary_learner_batch_size=128,
eval_freq=100,
igf_model_path="igf_model.pt",
)
# load pretrained gpt2 model
model = GPT2LMHeadModel.from_pretrained("gpt2")
set_seed(42)
# Generate train and test data to train and evaluate gpt2 model
train_dataset, test_dataset = generate_datasets(
context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
)
# fine-tuning of the gpt2 model using igf (Information Gain Filtration)
finetune(
model,
train_dataset,
test_dataset,
context_len=32,
max_steps=1000,
batch_size=16,
threshold=1.0,
recopy_model=recopy_gpt2,
secondary_learner=secondary_learner,
eval_interval=10,
finetuned_model_name="gpt2_finetuned.pt",
)
if __name__ == "__main__":
main()