Mirror of https://github.com/huggingface/transformers.git
Add Information Gain Filtration algorithm (#16953)
* Add information gain filtration algorithm
* Complying with black requirements
* Added author
* Fixed import order
* flake8 corrections

Co-authored-by: Javier Turek <javier.turek@intel.com>
This commit is contained in:
parent 91ede485a7
commit 5fdb54ece7
examples/research_projects/information-gain-filtration/README.md (new file, 100 lines)
@@ -0,0 +1,100 @@
# Information Gain Filtration (IGF)

Authors @Tuko @mraunak

This folder contains the code showing how to implement IGF for fine-tuning GPT-2.

## What is IGF?

Here we present a general fine-tuning method that we call Information Gain Filtration, for improving the overall training efficiency and final
performance of language model fine-tuning (see paper below). It is an alternative fine-tuning method that trains
a secondary model (e.g., a simple convolutional network) to predict the amount of information
gained over a given pre-trained model. The secondary model is lightweight and trained to
predict the Information Gain measure. Information Gain is defined as the change in a loss
function for a model before and after an SGD update with a sample (Equation X in the paper).
A small subset of the training set, named the "objective" set, is used to measure information
gain on the pre-trained model, and consequently to train the secondary model. After
training, the secondary model is used to filter samples during the fine-tuning process: a
high information gain value suggests a sample is informative, whereas a low value
suggests a non-informative sample that should be filtered out. A thresholding
strategy is therefore defined to select informative samples. With such a strategy, samples are filtered,
and once enough samples are selected to form a mini-batch, a usual fine-tuning/optimization
step is applied. The filtration process is repeated until fine-tuning is complete.
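A minimal sketch of this filtering loop (illustrative only; `model`, `optimizer`, `secondary_learner`, `contexts`, `threshold`, and `batch_size` are placeholder names, and the actual scripts in this folder differ in detail):

```python
import torch

def igf_filtered_updates(model, optimizer, secondary_learner, contexts, threshold, batch_size):
    """Accumulate gradients only from contexts the secondary learner predicts to be informative."""
    selected = 0
    for context in contexts:  # each context: LongTensor of shape (1, context_len)
        # Predicted information gain for this context
        predicted_ig = secondary_learner(context.unsqueeze(0))[0].item()
        if predicted_ig < threshold:
            continue  # low predicted gain: filter the sample out
        loss = model(context, labels=context)[0]
        loss.backward()  # accumulate gradients for the growing mini-batch
        selected += 1
        if selected == batch_size:  # enough informative samples: take one optimization step
            optimizer.step()
            optimizer.zero_grad()
            selected = 0
```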
Paper [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)

## Results

Several experiments were conducted to show the robustness of the IGF method versus the
standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
Books dataset compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
implemented using the Transformers library and PyTorch. While the method may seem more
expensive, we saw enough evidence that it may lead to a performance benefit in the final models.



Figure 1: Comparing IGF to Standard Fine-tuning:
IGF with constant (p < 10^-3, t-test) and shifting (p < 10^-6, t-test) thresholding significantly outperform standard fine-tuning. The left-hand figure shows
test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote ± one standard error). The right-hand figure shows the perplexity of each
method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam.
## How to use this project?

To fine-tune a transformer model with IGF on a language modeling task, use the following script and arguments:

- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
- `data_file`: A jbl file containing tokenized data which can be split as objective dataset,
  train_dataset and test_dataset
- `igf_data_file`: A jbl file containing the context and information gain pairs to train the secondary learner.
- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
  than this will be truncated, sequences shorter will be padded.
- `size_objective_set`: Number of articles that are long enough to be used as our objective set
- `min_len`: The minimum length of an article to be used in the objective set
- `trim`: Truncate the example if it exceeds context length
- `eval_freq`: Secondary model evaluation is triggered at eval_freq
- `max_steps`: Used to calculate training epochs
- `number`: The number of examples split to be used as objective_set/test_data
- `secondary_learner_batch_size`: The batch size of training data for the secondary learner
- `secondary_learner_max_epochs`: The number of epochs to train the secondary learner
- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
- `eval_interval`: Decay the selectivity of our secondary learner filter from
  1 standard deviation above average to 1 below average after eval_interval (10) batches
```bash
python run_clm_igf.py \
    --model_name_or_path "gpt2" \
    --data_file="data/tokenized_stories_train_wikitext103" \
    --igf_data_file="data/IGF_values" \
    --context_len 32 \
    --size_objective_set 100 \
    --min_len 1026 \
    --trim True \
    --eval_freq 100 \
    --max_steps 1000 \
    --secondary_learner_batch_size 128 \
    --secondary_learner_max_epochs 15 \
    --number 100 \
    --recopy_model \
    --eval_interval 10
```
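Once the script finishes, the secondary learner weights are saved (the script's default path is `igf_model.pt`). Below is a minimal, illustrative sketch of reloading it to score a single context; the tokenizer usage and example text are assumptions, not part of the scripts in this folder:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from igf.igf import SecondaryLearner

# Rebuild the secondary learner around the pretrained GPT-2 embeddings and load its weights
model = GPT2LMHeadModel.from_pretrained("gpt2")
secondary_learner = SecondaryLearner.from_pretrained("igf_model.pt", model)
secondary_learner.eval()

# Score one tokenized context: higher values mean the context is predicted to be more informative
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
context = tokenizer("The history of chess can be traced back", return_tensors="pt").input_ids
with torch.no_grad():
    predicted_information_gain = secondary_learner(context.unsqueeze(0))
print(predicted_information_gain.item())
```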
## Citation

If you find the resource useful, please cite the following paper:

```
@inproceedings{antonello-etal-2021-selecting,
    title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
    author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
    booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.acl-long.87",
    doi = "10.18653/v1/2021.acl-long.87",
    pages = "1072--1085",
}
```
@@ -0,0 +1,419 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

import copy
import logging
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import joblib
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup


logger = logging.getLogger(__name__)


def set_seed(seed):
    """
    For reproducible training

    Args:
        seed: A seed for reproducible training

    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def compute_perplexity(model, test_data, context_len):
    """
    Computes perplexity of the transformer model on data in test_data

    Args:
        model: Pre-trained GPT2 model
        test_data: Data on which perplexity calculation is required
        context_len: The maximum total input sequence length after tokenization. Sequences longer
        than this will be truncated, sequences shorter will be padded

    Returns:
        Perplexity on input test data

    """

    model.eval()
    device = next(model.parameters()).device
    eval_batch_size = 1
    context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
    eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
    eval_loss = torch.zeros(1, device=device)
    nb_eval_examples = 0
    for batch in eval_dataloader:
        batch.to(device)
        # pad
        context.zero_()
        for i in range(eval_batch_size):
            context[i, :] = batch[i]
        outputs = model(context, labels=context)
        eval_loss += outputs[0].sum().item()
        nb_eval_examples += batch.size(0)
    eval_loss = eval_loss / nb_eval_examples
    perplexity = torch.exp(eval_loss)
    model.train()
    return perplexity


def load_gpt2(model_name="gpt2"):
    """
    load original gpt2 and save off for quicker loading

    Args:
        model_name: GPT-2

    Returns:
        GPT-2 model

    """

    model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
    torch.save(model.state_dict(), model_name + "local.pt")
    return model


def recopy_gpt2(orig_model, device, max_steps):
    """
    Reset the model to the original pretrained GPT-2 weights after each iteration

    Args:
        orig_model: Original pretrained GPT-2 model imported from Transformers library
        device: CPU/GPU
        max_steps: number of training steps

    Returns:
        Original PreTrained GPT-2 model,
        lm_optimizer: Adam optimizer with Decoupled weight decay
        lm_scheduler: linear scheduler with the appropriate schedule

    """
    model = copy.deepcopy(orig_model)
    model.to(device)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
    torch.cuda.empty_cache()
    return model, lm_optimizer, lm_scheduler


def intermittent_save(contexts, real_perps, past_perps, filename):

    """
    save the perplexity differences to filename

    Args:
        contexts: Example on which the perplexity is calculated
        real_perps: Perplexity after back-propagating on the selected context
        past_perps: Perplexity of model before training on the context
        filename: File to store perplexity differences

    Returns:
        file with perplexity differences

    """
    # save the perplexity differences to filename
    avg = np.array(real_perps).mean()
    std = np.array(real_perps).std()
    perp_diff = (real_perps - avg) / std
    data_final = list(zip(contexts, perp_diff, past_perps))
    joblib.dump(data_final, filename)


def collect_objective_set(
    model,
    orig_perp,
    context_len,
    train_data,
    objective_set,
    max_steps,
    device,
    filename="dev.jbl",
    recopy_model=recopy_gpt2,
):

    """
    Collect individual IGF values from the pre-trained transformer model over
    max_steps samples of training data, to train the secondary model

    Args:
        model: Pre-trained GPT2 model
        orig_perp: Perplexity of original pretrained GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
        than this will be truncated, sequences shorter will be padded
        train_data: Data to train model
        objective_set: Contexts used to create (X,IG(X)) pairs which is the training data for secondary learner
        max_steps: To calculate training epochs of model
        device: GPU/CPU
        filename: To store intermediate perplexity differences
        recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration

    Returns:
        file storing the intermediate perplexity differences

    """

    # initialize variables to record relevant information
    contexts = []
    real_perps = []
    past_perps = []

    # Initialize the transformer model
    orig_model = copy.deepcopy(model)
    orig_model.to(device="cpu")
    torch.cuda.empty_cache()

    # Compute perplexity of initial transformer model for comparison
    model.train()
    model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)

    for step in tqdm(range(max_steps)):
        context = torch.zeros((1, context_len), dtype=torch.long, device=device)
        story = random.choice(train_data)
        start = random.randint(0, len(story[0]) - context_len - 1)
        context[0, :] = story[0][start : start + context_len]
        lm_optimizer.zero_grad()
        outputs = model(context, labels=context)
        lm_loss = outputs[0]
        past_perp = compute_perplexity(model, context, context_len)
        model.train()
        lm_loss.backward()
        # Do LM backprop
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        lm_optimizer.step()
        lm_scheduler.step()  # Update learning rate schedule

        # Compute perplexity after back-propagating on the selected context
        real_perp = compute_perplexity(model, objective_set, context_len)

        # Periodically save the stored (X, IG(X)) pairs
        if step % 1000 == 0 and step > 1:
            intermittent_save(contexts, real_perps, past_perps, filename)

        # Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
        model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)

        past_perps.append(past_perp.item())
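        # The IG target stored for this context is the drop in objective-set perplexity
        # (orig_perp - real_perp) caused by the single SGD update above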
        real_perps.append(orig_perp - real_perp.item())
        contexts.append(np.array(context.cpu()))

    intermittent_save(contexts, real_perps, past_perps, filename)


def generate_datasets(
    context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
):
    """
    Generate objective set and training set

    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
        than this will be truncated, sequences shorter will be padded
        file: Tokenized data split into training set and objective set
        number: size of objective dataset
        min_len: minimum length of a context in objective set
        trim: If True truncate the context if it exceeds context length

    Returns:
        Generated objective set and training data

    """
    # Generate objective set and training set
    # Designate the first number (100) articles that are long enough to be used
    # as our objective set, rest (that are long enough) are training data for
    # secondary learner

    data = joblib.load(file)
    print("data loaded")
    objective_set = []
    if trim:
        for i, example in enumerate(data):
            if len(example[0]) > min_len:
                start = random.randint(0, len(example[0]) - context_len - 1)
                objective_set.append(example[0, start : start + context_len])
            if len(objective_set) >= number:
                break
        train_data = []
        for j in range(i + 1, len(data)):
            if len(data[j][0]) > min_len:
                train_data.append(data[j])
    else:
        objective_set = data[0:number]
        train_data = data[number:]

    joblib.dump(objective_set, "objective_set.jbl")
    print("objective set saved")
    return train_data, objective_set


def train_secondary_learner(
    secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
):

    """
    Train the secondary learner (igf_model)

    Args:
        secondary_learner: secondary learner
        train_dataset: data to train secondary learner
        max_epochs: number of epochs to train secondary learner
        batch_size: batch size of training data of secondary learner
        eval_freq: secondary model evaluation can be triggered at eval_freq
        igf_model_path: path to store trained secondary learner

    Returns:
        Trained secondary learner

    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # We will use the first 512 pairs from our dataset as a test set for
    # our secondary learner and the rest to train
    test_dataset = train_dataset[:512]
    train_dataset = train_dataset[512:]
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

    # secondary learner model set up
    loss = nn.MSELoss()
    test_loss = nn.MSELoss(reduction="sum")
    secondary_learner.to(device)
    q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
    secondary_learner.train()

    # TODO in original code this is written as number of actual batches seen
    # not number of items seen but other places it is number of items instead.
    # improve consistency! changed this to epochs for clarity
    best_test_loss = float("inf")
    # Iterate through batches until we've used max_steps batches
    for epoch in range(int(max_epochs)):
        tr_q_loss = 0.0
        secondary_learner.train()
        for step, batch in enumerate(train_dataloader):
            context = batch[0].to(device)
            real_q = batch[1].to(device)
            predicted_q = secondary_learner(context)
            q_optimizer.zero_grad()
            q_loss = loss(predicted_q, real_q.float())
            q_loss.backward()
            q_optimizer.step()
            tr_q_loss += q_loss.item()

            # model trains fairly quickly so we won't wait for a full epoch
            # eval is triggered at eval_freq and end of epochs
            if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
                tr_loss = tr_q_loss / (step + 1)

                secondary_learner.eval()
                q_loss2 = 0.0
                sum_q2 = 0.0
                predicted = []
                actual = []
                # Compute performance of the secondary learner after this batch
                for step2, batch2 in enumerate(test_dataloader):
                    features2 = batch2[0].to(device)
                    real_q2 = batch2[1].to(device)
                    predicted_q2 = secondary_learner(features2)
                    q_loss2 += test_loss(predicted_q2, real_q2).item()
                    sum_q2 += torch.sum(predicted_q2).item()
                    for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
                        predicted.append(i.item())
                    for ei, i in enumerate(real_q2.cpu().detach().numpy()):
                        actual.append(i.item())

                q_loss2 /= len(test_dataset)
                print(
                    "Epoch: ",
                    epoch,
                    "step: ",
                    step,
                    "Avg. q:",
                    sum_q2 / len(test_dataset),
                    "Train Loss: ",
                    tr_loss,
                    "Test Loss: ",
                    q_loss2,
                )
                if q_loss2 < best_test_loss:
                    joblib.dump((predicted, actual), "pred_vs_actual.jbl")
                    torch.save(secondary_learner.state_dict(), igf_model_path)
                    best_test_loss = q_loss2

                secondary_learner.train()
    return secondary_learner


class SecondaryLearner(nn.Module):
    """
    Our secondary learner
    """

    def __init__(self, model):
        """
        We use a simple convolutional network as our secondary learner

        Args:
            model: Pre-trained GPT2 model
        """
        # embeddings are from the pretrained model
        super(SecondaryLearner, self).__init__()
        self.embeddings = model.transformer.wte
        self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
        self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
        self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))

    def forward(self, context):
        """
        Forward pass through the secondary learner

        Args:
            context: Context input to the secondary learner

        Returns:
            tensor after squeeze operation

        """
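        # Embed the token ids, apply a 1D convolution over the sequence, then global max-pool over positions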
        pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
        qs = self.fc(pooled)
        return qs.squeeze(1)

    @classmethod
    def from_pretrained(cls, state_path, model):
        """
        Load the secondary learner

        Args:
            state_path: Path to the saved secondary learner state dict
            model: Pretrained GPT-2

        Returns:
            secondary learner
        """

        secondary_learner = cls(model)  # this calls __init__
        state_dict = torch.load(state_path)
        secondary_learner.load_state_dict(state_dict)
        secondary_learner.embeddings = model.transformer.wte
        secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
        return secondary_learner
@@ -0,0 +1,6 @@
matplotlib
numpy>=1.17.2
joblib>=0.13.2
scipy
torch>=1.10.1
transformers>=3.5
Binary image file not shown (34 KiB).
@@ -0,0 +1,438 @@
# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration (IGF) on the WikiText data set, and a comparison of
the results with the standard fine-tuning method.

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)). IG(X)--Informativeness of context 'X'.
Our IG (information gain) model is learning to predict the 'informativeness' of a particular
context. Informativeness is the change in metric between the model's accuracy on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.

2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).

3) The learner created in (2) is used to inform the fine-tuning process and filter out low-informative samples.

Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering.

"""

# Prerequisite libraries:

import argparse
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler

import joblib
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):

    """
    Collecting *n* pairs for training the secondary learner

    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
        than this will be truncated, sequences shorter will be padded
        max_steps: To calculate training epochs of secondary learner
        size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
        min_len: The minimum length of the article to be used as objective set
        trim: If True truncate the context if it exceeds context length
        data_file: Tokenized data set split for training and evaluation of model
        igf_data_file: file to store (X, IG(X)) paired data set to train secondary learner

    Returns:
        Data stored in igf_data_file

    """
    # generates same data every time
    set_seed(3)
    # generate train_data and objective_set
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=1026, trim=True
    )
    # keeps model same across runs
    set_seed(4)
    # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
    # can we train on GPU?
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model
    model = load_gpt2("gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

    # collect igf pairs and save to file demo.jbl
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # clean up, delete model and data we don't need anymore
    del model, train_data, objective_set
    torch.cuda.empty_cache()


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set with (X,IG(X)) pairs to train secondary learner where IG(X) - measure of informativeness and X - context
        secondary_learner_max_epochs: Number of epochs to train secondary learner
        secondary_learner_batch_size: Batch size to train secondary learner
        eval_freq: secondary model evaluation can be triggered at eval_freq
        igf_model_path: path to store trained secondary learner

    Returns:
        Trained secondary learner
    """

    set_seed(42)

    # Load pre-trained model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Initialize secondary learner to use embedding weights of model
    secondary_learner = SecondaryLearner(model)

    # Train secondary learner
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=100,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner


def finetune(
    model,
    train_dataset,
    test_dataset,
    context_len=32,
    max_steps=1000,
    batch_size=16,
    threshold=1.0,
    recopy_model=recopy_gpt2,
    secondary_learner=None,
    eval_interval=10,
    finetuned_model_name="gpt2_finetuned.pt",
):
    """
    fine-tune with IGF if secondary_learner is not None, else standard fine-tuning

    Args:
        model: pre-trained GPT-2 model
        train_dataset: Data set to train GPT-2 model
        test_dataset: Evaluate GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
        than this will be truncated, sequences shorter will be padded
        max_steps: To calculate training epochs
        batch_size: Batch size to train GPT-2 model
        threshold: The threshold value used by secondary learner to filter the train_data and allow only
        informative data as input to the model
        recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
        secondary_learner: Selection of IGF as fine-tuning method if not None
        eval_interval: number of batches after which decay the selectivity of our secondary learner filter from
        1 standard deviation above average to 1 below average
        finetuned_model_name: name of the final fine-tuned GPT-2 model

    Returns:
        Fine-tuned GPT-2 model

    """

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // (len(train_dataset)) + 1
    global_step = 0
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()
    contexts = []
    examples = 0

    observed_qs = []
    test_perps = []

    # Compute the performance of the transformer model at the beginning
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)
    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                predicted_q = secondary_learner.forward(
                    torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
                )[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value
                # We will decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after 10 batches.

                if global_step == 10:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If we passed the filter, add the context to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, backprop on the batch.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update learning rate schedule
                global_step += 1
                # Compute the performance of the transformer model at this batch
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)

                    print("Test perplexity, step", global_step, ":", real_perp)
            # Break out of the loop after 60 batches
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # save finetuned transformer model
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()
    # Do some cleaning up so we can reinitialize for the next run of this function
    del lm_optimizer
    del lm_scheduler
    return model


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help="A jbl file containing tokenized data which can be split as objective dataset, "
        "train_dataset and test_dataset.",
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for secondary learner",
    )

    parser.add_argument(
        "--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help="decay the selectivity of our secondary learner filter from "
        "1 standard deviation above average to 1 below average after 10 batches",
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
    )

    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help="The threshold value used by secondary learner to filter the train_data and allow only"
        " informative data as input to the model",
    )

    parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")

    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        type=str,
        help="Reset the model to the original pretrained GPT-2 weights after each iteration",
    )

    # function calls
    # Collecting *n* pairs of context and information gain (X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=32,
        max_steps=10,
        size_objective_set=100,
        min_len=1026,
        trim=True,
        data_file="data/tokenized_stories_train_wikitext103.jbl",
        igf_data_file="igf_context_pairs.jbl",
    )

    # Load train data for secondary learner
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=15,
        secondary_learner_batch_size=128,
        eval_freq=100,
        igf_model_path="igf_model.pt",
    )

    # load pretrained gpt2 model
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    set_seed(42)

    # Generate train and test data to train and evaluate gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
    )

    # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=32,
        max_steps=1000,
        batch_size=16,
        threshold=1.0,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=10,
        finetuned_model_name="gpt2_finetuned.pt",
    )


if __name__ == "__main__":
    main()