Add CodeParrot 🦜 codebase (#14536)

* add readme skeleton

* update readme

* add initialization script

* add deduplication script

* add codeparrot training script

* add code generation evaluation

* add validation loss script

* add requirements

* update readme

* tweak readme

* make style

* add highlights to readme

* add CLIs to scripts

* add tokenizer training script

* add docstring to constant length dataset

* fix defaults in arguments

* update readme with cli

* move image to hub

* tweaks of readme

* fix cli commands

* add author

* explain env variables

* fix formatting

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* replace generic with gpt2 tokenizer

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Leandro von Werra 2021-12-02 10:41:35 +01:00 committed by GitHub
parent e4c67d60ec
commit 43f953cc2e
9 changed files with 942 additions and 0 deletions

README.md
@@ -0,0 +1,158 @@
# CodeParrot 🦜
<p align="center">
<img src="https://huggingface.co/datasets/lvwerra/repo-images/raw/main/code-highlighting-streamlit.png" alt="drawing" width="350"/>
</p>
## What is this about?
This is an open-source effort to train and evaluate code generation models. CodeParrot 🦜 is a GPT-2 model trained from scratch on Python code. The highlights of this project are:
- initialize and train a GPT-2 language model from scratch for code generation
- train a custom tokenizer adapted for Python code
- clean and deduplicate a large (>100GB) dataset with `datasets`
- train with `accelerate` on multiple GPUs using data parallelism and mixed precision
- continuously push checkpoints to the hub with `huggingface_hub`
- stream the dataset with `datasets` during training to avoid disk bottlenecks
- apply the `code_eval` metric in `datasets` to evaluate on [OpenAI's _HumanEval_ benchmark](https://huggingface.co/datasets/openai_humaneval)
## Installation
To install the dependencies simply run the following command:
```bash
pip install -r requirements.txt
```
To reproduce the results you can follow the scripts in the following sections. Note that we don't always show all possible arguments to the scripts. To get the full list of arguments with descriptions you can run the following command on any script:
```bash
python scripts/some_script.py --help
```
Before you run any of the scripts make sure you are logged in and can push to the hub:
```bash
huggingface-cli login
```
## Dataset
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files smaller than 1MB, resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
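If you just want to inspect the data, you can stream a few samples without downloading the full 180GB first. A minimal sketch (assuming the raw dump exposes the same `content` column used by the preprocessing script):
```python
from datasets import load_dataset

# Stream the raw dataset so nothing needs to be downloaded to disk up front
ds = load_dataset("transformersbook/codeparrot", split="train", streaming=True)

# The Python source of each file is stored in the "content" column
example = next(iter(ds))
print(example["content"][:500])
```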
### Preprocessing
The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374):
- exact deduplication using each file's hash
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- filtering files with a fraction of alphanumeric characters < 0.25
- filtering files containing the word "auto-generated" or similar in the first 5 lines
The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/lvwerra/codeparrot-clean-train) and [validation](https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid) splits are also available on the Hub if you want to skip this step or use the data for another project.
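If you skip the preprocessing step, the cleaned splits can be streamed straight from the Hub. A short sketch using the split names linked above:
```python
from datasets import load_dataset

# Stream the deduplicated splits instead of materializing them on disk
train_ds = load_dataset("lvwerra/codeparrot-clean-train", split="train", streaming=True)
valid_ds = load_dataset("lvwerra/codeparrot-clean-valid", split="train", streaming=True)

print(next(iter(valid_ds))["content"][:200])
```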
To execute the preprocessing run the following command:
```bash
python scripts/preprocessing.py \
--dataset_name lvwerra/codeparrot \
--output_dir codeparrot-clean
```
During preprocessing the dataset is downloaded and stored locally, along with caches of the intermediate computations. Make sure you have more than 500GB of free disk space before executing it.
## Tokenizer
Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
```bash
python scripts/bpe_training.py \
    --base_tokenizer gpt2 \
    --dataset_name lvwerra/codeparrot-clean-train
```
_Note:_ We originally trained the tokenizer on the unprocessed train split of the dataset `transformersbook/codeparrot-train`.
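To sanity-check the result, you can compare how the new tokenizer and the base GPT-2 tokenizer split a snippet of Python code. A rough sketch (assuming the trained tokenizer was pushed to the Hub as `lvwerra/codeparrot`, the name used in the sections below):
```python
from transformers import AutoTokenizer

code = "def fibonacci(n):\n    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)"

gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
code_tokenizer = AutoTokenizer.from_pretrained("lvwerra/codeparrot")

# A tokenizer trained on Python code should need noticeably fewer tokens for the same snippet
print("gpt2:      ", len(gpt2_tokenizer(code)["input_ids"]))
print("codeparrot:", len(code_tokenizer(code)["input_ids"]))
```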
## Training
The models are randomly initialized and trained from scratch. To initialize a new model you can run:
```bash
python scripts/initialize_model.py \
--config_name gpt2-large \
--tokenizer_name lvwerra/codeparrot \
--model_name codeparrot \
--push_to_hub True
```
This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initialized model is pushed to the hub.
Now that the dataset, tokenizer, and model are ready we can start training. The main training script is built with `accelerate` so it scales across a wide range of platforms and hardware setups. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine, which takes about 1 day and 1 week, respectively.
First you need to configure `accelerate` and login to Weights & Biases:
```bash
accelerate config
wandb login
```
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run
```bash
accelerate launch scripts/codeparrot_training.py
```
If you want to train the small model you need to make some modifications:
```bash
accelerate launch scripts/codeparrot_training.py \
--model_ckpt lvwerra/codeparrot-small \
--train_batch_size 12 \
    --valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 2000 \
    --gradient_accumulation_steps 1 \
--gradient_checkpointing False \
--max_train_steps 150000 \
--save_checkpoint_steps 15000
```
Recall that you can see the full set of possible options with descriptions (for all scripts) by running:
```bash
python scripts/codeparrot_training.py --help
```
## Evaluation
For evaluating the language modeling loss on the validation set or any other dataset you can use the following command:
```bash
python scripts/validation_loss.py \
--model_ckpt lvwerra/codeparrot \
--dataset_name lvwerra/codeparrot-clean-valid
```
In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. You can run the evaluation with the following command:
```bash
python scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
--do_sample True \
--temperature 0.2 \
--top_p 0.95 \
    --n_samples 200
```
The results as well as reference values are shown in the following table:
| Model | pass@1 | pass@10 | pass@100|
|-------|--------|---------|---------|
|CodeParrot 🦜 (110M) | 3.80% | 6.57% | 12.78% |
|CodeParrot 🦜 (1.5B) | 3.58% | 8.03% | 14.96% |
|||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|Codex (85M)| 8.22% | 12.81% | 22.40% |
|Codex (300M)| 13.17%| 20.37% | 36.27% |
|Codex (12B)| 28.81%| 46.81% | 72.31% |
|||||
|GPT-neo (125M)| 0.75% | 1.88% | 2.97% |
|GPT-neo (1.5B)| 4.79% | 7.47% | 16.30% |
|GPT-neo (2.7B)| 6.41% | 11.27% | 21.37% |
|GPT-J (6B)| 11.62% | 15.74% | 27.74% |
The numbers were obtained by sampling with `T = [0.2, 0.6, 0.8]` and picking the best value for each metric. Both CodeParrot 🦜 models are still underfitted and longer training would likely improve the performance.
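For reference, `pass@k` is computed with the unbiased estimator from the Codex paper, which the `code_eval` metric implements. A minimal sketch of the formula, where `n` completions are generated per task and `c` of them pass the unit tests:
```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimator of pass@k: 1 - C(n-c, k) / C(n, k), computed in a numerically stable way."""
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

# e.g. 200 samples per task of which 10 pass the tests, evaluated at k=100
print(pass_at_k(n=200, c=10, k=100))
```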
## Demo
Give the model a shot yourself! There are two demos to interact with CodeParrot 🦜:
- [Code generation](https://huggingface.co/spaces/lvwerra/codeparrot-generation)
- [Code highlighting](https://huggingface.co/spaces/lvwerra/codeparrot-highlighting)
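If you prefer to run the model locally, here is a minimal generation sketch with the `transformers` pipeline (assuming the `lvwerra/codeparrot` checkpoint on the Hub; use `lvwerra/codeparrot-small` for the 110M model):
```python
from transformers import pipeline

# Load the trained CodeParrot checkpoint into a text-generation pipeline
pipe = pipeline("text-generation", model="lvwerra/codeparrot")

prompt = "def get_file_size(filepath):"
completion = pipe(prompt, max_new_tokens=64, do_sample=True, temperature=0.2, top_p=0.95)
print(completion[0]["generated_text"])
```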
## Further Resources
A detailed description of the project can be found in the chapter "Training Transformers from Scratch" in the upcoming O'Reilly book [Natural Language Processing with Transformers](https://learning.oreilly.com/library/view/natural-language-processing/9781098103231/).
This example was provided by [Leandro von Werra](https://github.com/lvwerra).

requirements.txt
@@ -0,0 +1,7 @@
transformers==4.12.2
datasets==1.16.0
accelerate==0.5.1
wandb==0.12.0
tensorboard==2.6.0
torch==1.9.0
huggingface-hub==0.0.19

scripts/arguments.py
@@ -0,0 +1,175 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TrainingArguments:
"""
Configuration for training model.
"""
model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be trained."},
)
save_dir: Optional[str] = field(
default="./",
metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
)
dataset_name_train: Optional[str] = field(
default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
)
dataset_name_valid: Optional[str] = field(
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
)
train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
shuffle_buffer: Optional[int] = field(
default=1000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
)
    learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate for training."})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate scheduler type."})
num_warmup_steps: Optional[int] = field(
default=750, metadata={"help": "Number of warmup steps in the learning rate schedule."}
)
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "Number of gradient accumulation steps."}
)
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
)
max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
max_eval_steps: Optional[int] = field(
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
)
seq_length: Optional[int] = field(default=1024, metadata={"help": "Sequence lengths used for training."})
seed: Optional[int] = field(default=1, metadata={"help": "Training seed."})
save_checkpoint_steps: Optional[int] = field(
default=1024,
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
)
@dataclass
class EvaluationArguments:
"""
Configuration for evaluating model.
"""
model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be evaluated."},
)
dataset_name: Optional[str] = field(
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
)
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size used for evaluation."})
max_eval_steps: Optional[int] = field(
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
)
seq_length: Optional[int] = field(default=1024, metadata={"help": "Length of sequences to be evaluated."})
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
@dataclass
class HumanEvalArguments:
"""
Configuration for running evaluation on HumanEval dataset.
"""
model_ckpt: Optional[str] = field(
default="lvwerra/codeparrot",
metadata={"help": "Model name or path of model to be evaluated."},
)
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
do_sample: Optional[bool] = field(
default=True, metadata={"help": "Sample from the language model's output distribution."}
)
temperature: Optional[float] = field(default=0.2, metadata={"help": "Sampling temperature used for generation."})
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "Maximum number of newly generated tokens."})
top_k: Optional[int] = field(default=0, metadata={"help": "Top-k parameter used for generation."})
top_p: Optional[float] = field(default=0.95, metadata={"help": "Top-p parameter used for nucleus sampling."})
batch_size: Optional[int] = field(default=10, metadata={"help": "Number of generations to run in parallel."})
n_samples: Optional[int] = field(
default=200, metadata={"help": "Number of completions to generate for each sample."}
)
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
    output_file: Optional[str] = field(
        default="eval_results.json", metadata={"help": "Output file to store the evaluation results."}
    )
HF_ALLOW_CODE_EVAL: Optional[str] = field(
default="0", metadata={"help": "Allow `code_eval` to execute Python code on machine"}
)
@dataclass
class PreprocessingArguments:
"""
Configuration for preprocessing data.
"""
num_workers: Optional[int] = field(
default=None,
metadata={
"help": "The number of CPU cores to use for parallel preprocessing. Default uses the maximum available."
},
)
dataset_name: Optional[str] = field(
default="codeparrot", metadata={"help": "Folder or name of dataset to process."}
)
output_dir: Optional[str] = field(
default="codeparrot-clean", metadata={"help": "Folder to save processed processed dataset."}
)
samples_per_file: Optional[int] = field(
default=100_000, metadata={"help": "Number of files to save per JSON output file."}
)
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
line_max: Optional[float] = field(
default=1000, metadata={"help": "Maximum line length in file, otherwise file is filtered."}
)
line_mean: Optional[float] = field(
default=100, metadata={"help": "Maximum mean line length in file, otherwise file is filtered."}
)
alpha_frac: Optional[float] = field(
        default=0.25, metadata={"help": "Minimum fraction of alphanumeric characters, otherwise file is filtered."}
)
@dataclass
class TokenizerTrainingArguments:
"""
Configuration for tokenizer training.
"""
base_tokenizer: Optional[str] = field(
default="gpt2",
metadata={"help": "Base tokenizer to build new tokenizer from."},
)
dataset_name: Optional[str] = field(
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
)
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
    vocab_size: Optional[int] = field(default=200000, metadata={"help": "Vocabulary size of the new tokenizer."})
n_examples: Optional[int] = field(
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
)
tokenizer_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of new tokenizer."})
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
@dataclass
class InitializationArguments:
"""
Configuration for initializing new model.
"""
config_name: Optional[str] = field(
default="gpt2-large",
metadata={"help": "Configuration to use for model initialization."},
)
tokenizer_name: Optional[str] = field(
default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}
)
model_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of the created model."})
    push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push the initialized model to the hub."})

scripts/bpe_training.py
@@ -0,0 +1,32 @@
from datasets import load_dataset
from tqdm import tqdm
from arguments import TokenizerTrainingArguments
from transformers import GPT2Tokenizer, HfArgumentParser
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
# Iterator for Training
def batch_iterator(batch_size=10):
for _ in tqdm(range(0, args.n_examples, batch_size)):
yield [next(iter_dataset)[args.text_column] for _ in range(batch_size)]
# Configuration
parser = HfArgumentParser(TokenizerTrainingArguments)
args = parser.parse_args()
# Base tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(args.base_tokenizer)
base_vocab = list(bytes_to_unicode().values())
# Load dataset
dataset = load_dataset(args.dataset_name, split="train", streaming=True)
iter_dataset = iter(dataset)
# Training and saving
new_tokenizer = tokenizer.train_new_from_iterator(
batch_iterator(), vocab_size=args.vocab_size, initial_alphabet=base_vocab
)
new_tokenizer.save_pretrained(args.tokenizer_name, push_to_hub=args.push_to_hub)

scripts/codeparrot_training.py
@@ -0,0 +1,240 @@
import logging
from argparse import Namespace
from pathlib import Path
import datasets
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
import transformers
import wandb
from accelerate import Accelerator
from arguments import TrainingArguments
from huggingface_hub import Repository
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        dataset (datasets.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is restarted when the dataset is exhausted, otherwise iteration stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in the buffer.
        chars_per_token (float): Number of characters per token used to estimate the number of tokens in the text buffer.
"""
def __init__(
self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.input_characters = seq_length * chars_per_token * num_of_sequences
self.epoch = 0
self.infinite = infinite
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.input_characters:
break
try:
buffer.append(next(iterator)["content"])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
self.epoch += 1
logger.info(f"Dataset epoch: {self.epoch}")
else:
more_examples = False
break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
yield torch.tensor(input_ids)
def setup_logging(args):
project_name = args.model_ckpt.split("/")[-1]
logger = logging.getLogger(__name__)
log_dir = Path(args.save_dir) / "log/"
log_dir.mkdir(exist_ok=True)
filename = f"debug_{accelerator.process_index}.log"
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
)
if accelerator.is_main_process: # we only want to setup logging once
wandb.init(project=project_name, config=args)
run_name = wandb.run.name
tb_writer = SummaryWriter()
tb_writer.add_hparams(vars(args), {"0": 0})
logger.setLevel(logging.INFO)
datasets.utils.logging.set_verbosity_info()
transformers.utils.logging.set_verbosity_info()
else:
tb_writer = None
run_name = ""
logger.setLevel(logging.ERROR)
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
return logger, tb_writer, run_name
def create_dataloaders(args):
ds_kwargs = {"streaming": True}
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
return train_dataloader, eval_dataloader
def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
params_with_wd, params_without_wd = [], []
for n, p in model.named_parameters():
if any(nd in n for nd in no_decay):
params_without_wd.append(p)
else:
params_with_wd.append(p)
return [
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
]
def log_metrics(step, metrics):
logger.info(f"Step {step}: {metrics}")
if accelerator.is_main_process:
wandb.log(metrics)
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
def evaluate(args):
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(batch, labels=batch)
loss = outputs.loss.repeat(args.valid_batch_size)
losses.append(accelerator.gather(loss))
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
break
loss = torch.mean(torch.cat(losses))
try:
perplexity = torch.exp(loss)
except OverflowError:
perplexity = float("inf")
return loss.item(), perplexity.item()
# Accelerator
accelerator = Accelerator()
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)
# Clone model repository
if accelerator.is_main_process:
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
# Logging
logger, tb_writer, run_name = setup_logging(args)
logger.info(accelerator.state)
# Checkout new branch on repo
if accelerator.is_main_process:
hf_repo.git_checkout(run_name, create_branch_ok=True)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)
# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
def get_lr():
return optimizer.param_groups[0]["lr"]
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)
# Train model
model.train()
completed_steps = 0
for step, batch in enumerate(train_dataloader, start=1):
loss = model(batch, labels=batch, use_cache=False).loss
log_metrics(
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0:
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
completed_steps += 1
if step % args.save_checkpoint_steps == 0:
logger.info("Evaluating and saving model checkpoint")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
if accelerator.is_main_process:
hf_repo.push_to_hub(commit_message=f"step {step}")
model.train()
if completed_steps >= args.max_train_steps:
break
# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
if accelerator.is_main_process:
hf_repo.push_to_hub(commit_message="final model")

scripts/human_eval.py
@@ -0,0 +1,87 @@
import json
import multiprocessing
import os
import re
from datasets import load_dataset, load_metric
from tqdm import tqdm
import transformers
from arguments import HumanEvalArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed
def first_block(string):
"""Split off first block of code by scanning for class, def etc. on newlines."""
return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()
def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
"""Complete prompt with text generation pipeline and return num_completions."""
prompt = pipe.tokenizer.eos_token + prompt
code_gens = pipe(prompt, num_return_sequences=num_completions, **gen_kwargs)
return [first_block(code_gen["generated_text"][len(prompt) :]) for code_gen in code_gens]
def main():
# Setup configuration
parser = HfArgumentParser(HumanEvalArguments)
args = parser.parse_args()
transformers.logging.set_verbosity_error()
# enables code execution in code_eval metric
os.environ["HF_ALLOW_CODE_EVAL"] = args.HF_ALLOW_CODE_EVAL
# make sure tokenizer plays nice with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if args.num_workers is None:
args.num_workers = multiprocessing.cpu_count()
set_seed(args.seed)
# Generation settings
gen_kwargs = {
"do_sample": args.do_sample,
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"top_p": args.top_p,
"top_k": args.top_k,
}
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
# Load evaluation dataset and metric
human_eval = load_dataset("openai_humaneval")
code_eval_metric = load_metric("code_eval")
# Generate completions for evaluation set
    n_tasks = len(human_eval["test"])
generations, references = [], []
for task in tqdm(range(n_tasks)):
task_generations = []
prompt = human_eval["test"][task]["prompt"].strip()
for batch in range(args.n_samples // args.batch_size):
task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
generations.append([prompt + gen for gen in task_generations])
test_func = human_eval["test"][task]["test"]
entry_point = f"check({human_eval['test'][task]['entry_point']})"
references.append("\n" + test_func + "\n" + entry_point)
# Evaluate completions with "code_eval" metric
pass_at_k, _ = code_eval_metric.compute(
references=references, predictions=generations, num_workers=args.num_workers
)
print(f"Results: {pass_at_k}")
# Save results to json file
with open(args.output_file, "w") as fp:
json.dump(pass_at_k, fp)
# For some reason the following seems to be necessary sometimes for code_eval to work nicely with multiprocessing
# https://stackoverflow.com/questions/60804599/python-multiprocessing-keeps-spawning-the-whole-script
if __name__ == "__main__":
main()

scripts/initialize_model.py
@@ -0,0 +1,22 @@
from arguments import InitializationArguments
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
# Configuration
parser = HfArgumentParser(InitializationArguments)
args = parser.parse_args()
# Load codeparrot tokenizer trained for Python code tokenization
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
config_kwargs = {"vocab_size": len(tokenizer), "scale_attn_by_layer_idx": True, "reorder_and_upcast_attn": True}
# Load model config (GPT-2 large in this case)
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
# Initialize new model with config
model = AutoModelForCausalLM.from_config(config)
# Save model to the hub
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)

scripts/preprocessing.py
@@ -0,0 +1,122 @@
import gzip
import multiprocessing
import os
import shutil
import time
import numpy as np
from datasets import load_dataset
from arguments import PreprocessingArguments
from transformers import HfArgumentParser
def get_hash(example):
"""Get hash of content field."""
return {"hash": hash(example["content"])}
def line_stats(example):
"""Calculates mean and max line length of file."""
line_lengths = [len(line) for line in example["content"].splitlines()]
return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}
def alpha_stats(example):
"""Calculates mean and max line length of file."""
alpha_frac = np.mean([c.isalnum() for c in example["content"]])
return {"alpha_frac": alpha_frac}
def check_uniques(example, uniques):
"""Check if current hash is still in set of unique hashes and remove if true."""
if example["hash"] in uniques:
uniques.remove(example["hash"])
return True
else:
return False
def is_autogenerated(example, scan_width=5):
"""Check if file is autogenerated by looking for keywords in the first few lines of the file."""
keywords = ["auto-generated", "autogenerated", "automatically generated"]
lines = example["content"].splitlines()
for _, line in zip(range(scan_width), lines):
for keyword in keywords:
if keyword in line.lower():
return {"autogenerated": True}
else:
return {"autogenerated": False}
def preprocess(example):
"""Chain all preprocessing steps into one function to not fill cache."""
results = dict()
results.update(get_hash(example))
results.update(line_stats(example))
results.update(alpha_stats(example))
results.update(is_autogenerated(example))
return results
def filter(example, uniques, args):
"""Filter dataset with heuristics."""
if not check_uniques(example, uniques):
return False
elif example["autogenerated"]:
return False
elif example["line_max"] > args.line_max:
return False
elif example["line_mean"] > args.line_mean:
return False
elif example["alpha_frac"] < args.alpha_frac:
return False
else:
return True
def compress_file(file_path):
"""Compress a file with g-zip."""
with open(file_path, "rb") as f_in:
with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(file_path)
# Settings
parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args()
if args.num_workers is None:
args.num_workers = multiprocessing.cpu_count()
# Load dataset
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Time to load dataset: {time.time()-t_start:.2f}")
# Run preprocessing
t_start = time.time()
ds = ds.map(preprocess, num_proc=args.num_workers)
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")
# Deduplicate hashes
uniques = set(ds.unique("hash"))
frac = len(uniques) / len(ds)
print(f"Fraction of duplicates: {1-frac:.2%}")
# Deduplicate data and apply heuristics
t_start = time.time()
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")
# Save data in batches of samples_per_file
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
end_index = min(len(ds_filter), index + args.samples_per_file)
ds_filter.select(list(range(index, end_index))).to_json(file_path)
compress_file(file_path)
print(f"Time to save dataset: {time.time()-t_start:.2f}")

scripts/validation_loss.py
@@ -0,0 +1,99 @@
import logging
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from accelerate import Accelerator
from arguments import EvaluationArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
class ConstantLengthDataset(IterableDataset):
def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.input_characters = seq_length * chars_per_token * num_of_sequences
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.input_characters:
break
try:
buffer.append(next(iterator)["content"])
buffer_len += len(buffer[-1])
except StopIteration:
more_examples = False
break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
yield torch.tensor(input_ids)
def create_dataloader(args):
ds_kwargs = {"streaming": True}
valid_data = load_dataset(args.dataset_name, split="train", **ds_kwargs)
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
return eval_dataloader
def evaluate(args):
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(batch, labels=batch)
loss = outputs.loss.repeat(args.batch_size)
losses.append(accelerator.gather(loss))
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
break
loss = torch.mean(torch.cat(losses))
try:
perplexity = torch.exp(loss)
except OverflowError:
perplexity = float("inf")
return loss.item(), perplexity.item()
# Setup Accelerator
accelerator = Accelerator()
# Parse configuration
parser = HfArgumentParser(EvaluationArguments)
args = parser.parse_args()
set_seed(args.seed)
# Logging
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
# Load dataset and dataloader
eval_dataloader = create_dataloader(args)
# Prepare everything with our `accelerator`.
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")