mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add CodeParrot 🦜 codebase (#14536)
* add readme skeleton * update readme * add initialization script * add deduplication script * add codeparrot training script * add code generation evaluation * add validation loss script * add requirements * update readme * tweak readme * make style * add highlights to readme * add CLIs to scripts * add tokenizer training script * add docstring to constant length dataset * fix defaults in arguments * update readme with cli * move image to hub * tweaks of readme * fix cli commands * add author * explain env variables * fix formatting * Update examples/research_projects/codeparrot/README.md Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Apply suggestions from code review Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * replace generic with gpt2 tokenizer Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
e4c67d60ec
commit
43f953cc2e
158
examples/research_projects/codeparrot/README.md
Normal file
158
examples/research_projects/codeparrot/README.md
Normal file
@ -0,0 +1,158 @@
|
||||
# CodeParrot 🦜
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/lvwerra/repo-images/raw/main/code-highlighting-streamlit.png" alt="drawing" width="350"/>
|
||||
</p>
|
||||
|
||||
## What is this about?
|
||||
This is an open-source effort to train and evaluate code generation models. CodeParrot 🦜 is a GPT-2 model trained from scratch on Python code. The highlights of this project are:
|
||||
- initialize and train a GPT-2 language model from scratch for code generation
|
||||
- train a custom tokenizer adapted for Python code
|
||||
- clean and deduplicate a large (>100GB) dataset with `datasets`
|
||||
- train with `accelerate` on multiple GPUs using data parallelism and mixed precision
|
||||
- continuously push checkpoints to the hub with `huggingface_hub`
|
||||
- stream the dataset with `datasets` during training to avoid disk bottlenecks
|
||||
- apply the `code_eval` metric in `datasets` to evaluate on [OpenAI's _HumanEval_ benchmark](https://huggingface.co/datasets/openai_humaneval)
|
||||
|
||||
## Installation
|
||||
To install the dependencies simply run the following command:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
To reproduce the results you can follow the scripts in the following sections. Note that we don't always show all possible arguments to the scripts. To get the full list of arguments with descriptions you can run the following command on any script:
|
||||
|
||||
```bash
|
||||
python scripts/some_script.py --help
|
||||
```
|
||||
|
||||
Before you run any of the scripts make sure you are logged in and can push to the hub:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## Dataset
|
||||
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files with less than 1MB in size resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
|
||||
|
||||
### Preprocessing
|
||||
The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374):
|
||||
|
||||
- exact deduplication using each file's hash
|
||||
- filtering files with max line length > 1000
|
||||
- filtering files with mean line length > 100
|
||||
- fraction of alphanumeric characters < 0.25
|
||||
- containing the word "auto-generated" or similar in the first 5 lines
|
||||
|
||||
The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/lvwerra/codeparrot-clean-train) and [validation](https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid) splits are also available on the Hub if you want to skip this step or use the data for another project.
|
||||
|
||||
To execute the preprocessing run the following command:
|
||||
```bash
|
||||
python scripts/preprocessing.py \
|
||||
--dataset_name lvwerra/codeparrot \
|
||||
--output_dir codeparrot-clean
|
||||
```
|
||||
During preprocessing the dataset is downloaded and stored locally as well as caches of the computations. Make sure you have more than 500GB free disk space to execute it.
|
||||
|
||||
## Tokenizer
|
||||
Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
|
||||
```bash
|
||||
python scripts/bpe_training.py
|
||||
--base_tokenizer gpt2
|
||||
--dataset_name lvwerra/codeparrot-clean-train
|
||||
```
|
||||
|
||||
_Note:_ We originally trained the tokenizer on the unprocessed train split of the dataset `transformersbook/codeparrot-train`.
|
||||
|
||||
## Training
|
||||
The models are randomly initialized and trained from scratch. To initialize a new model you can run:
|
||||
|
||||
```bash
|
||||
python scripts/initialize_model.py \
|
||||
--config_name gpt2-large \
|
||||
--tokenizer_name lvwerra/codeparrot \
|
||||
--model_name codeparrot \
|
||||
--push_to_hub True
|
||||
```
|
||||
This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initilaized model is pushed the the hub.
|
||||
|
||||
Now that the dataset, tokenizer, and model are ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
|
||||
|
||||
First you need to configure `accelerate` and login to Weights & Biases:
|
||||
|
||||
```bash
|
||||
acclerate config
|
||||
wandb login
|
||||
```
|
||||
|
||||
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run
|
||||
|
||||
```bash
|
||||
python scripts/codeparrot_training.py
|
||||
```
|
||||
|
||||
If you want to train the small model you need to make some modifications:
|
||||
|
||||
```bash
|
||||
accelerate launch scripts/codeparrot_training.py \
|
||||
--model_ckpt lvwerra/codeparrot-small \
|
||||
--train_batch_size 12 \
|
||||
--eval_batch_size 12 \
|
||||
--learning_rate 5e-4 \
|
||||
--num_warmup_steps 2000 \
|
||||
--gradient_accumulation 1 \
|
||||
--gradient_checkpointing False \
|
||||
--max_train_steps 150000 \
|
||||
--save_checkpoint_steps 15000
|
||||
```
|
||||
|
||||
Recall that you can see the full set of possible options with descriptions (for all scripts) by running:
|
||||
|
||||
```bash
|
||||
python scripts/codeparrot_training.py --help
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
For evaluating the language modeling loss on the validation set or any other dataset you can use the following command:
|
||||
```bash
|
||||
python scripts/validation_loss.py \
|
||||
--model_ckpt lvwerra/codeparrot \
|
||||
--dataset_name lvwerra/codeparrot-clean-valid
|
||||
```
|
||||
In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. You can run the evaluation with the following command:
|
||||
|
||||
```bash
|
||||
python scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
|
||||
--do_sample True \
|
||||
--temperature 0.2 \
|
||||
--top_p 0.95 \
|
||||
--n_samples=200
|
||||
```
|
||||
|
||||
The results as well as reference values are shown in the following table:
|
||||
|
||||
| Model | pass@1 | pass@10 | pass@100|
|
||||
|-------|--------|---------|---------|
|
||||
|CodeParrot 🦜 (110M) | 3.80% | 6.57% | 12.78% |
|
||||
|CodeParrot 🦜 (1.5B) | 3.58% | 8.03% | 14.96% |
|
||||
|||||
|
||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|
||||
|Codex (85M)| 8.22% | 12.81% | 22.40% |
|
||||
|Codex (300M)| 13.17%| 20.37% | 36.27% |
|
||||
|Codex (12B)| 28.81%| 46.81% | 72.31% |
|
||||
|||||
|
||||
|GPT-neo (125M)| 0.75% | 1.88% | 2.97% |
|
||||
|GPT-neo (1.5B)| 4.79% | 7.47% | 16.30% |
|
||||
|GPT-neo (2.7B)| 6.41% | 11.27% | 21.37% |
|
||||
|GPT-J (6B)| 11.62% | 15.74% | 27.74% |
|
||||
|
||||
The numbers were obtained by sampling with `T = [0.2, 0.6, 0.8]` and picking the best value for each metric. Both CodeParrot 🦜 models are still underfitted and longer training would likely improve the performance.
|
||||
|
||||
## Demo
|
||||
Give the model a shot yourself! There are two demos to interact with CodeParrot 🦜:
|
||||
- [Code generation](https://huggingface.co/spaces/lvwerra/codeparrot-generation)
|
||||
- [Code highlighting](https://huggingface.co/spaces/lvwerra/codeparrot-highlighting)
|
||||
|
||||
## Further Resources
|
||||
A detailed description of the project can be found in the chapter "Training Transformers from Scratch" in the upcoming O'Reilly book [Natural Language Processing with Transformers](https://learning.oreilly.com/library/view/natural-language-processing/9781098103231/).
|
||||
|
||||
This example was provided by [Leandro von Werra](www.github.com/lvwerra).
|
7
examples/research_projects/codeparrot/requirements.txt
Normal file
7
examples/research_projects/codeparrot/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
transformers==4.12.2
|
||||
datasets==1.16.0
|
||||
accelerate==0.5.1
|
||||
wandb==0.12.0
|
||||
tensorboard==2.6.0
|
||||
torch==1.9.0
|
||||
huggingface-hub==0.0.19
|
175
examples/research_projects/codeparrot/scripts/arguments.py
Normal file
175
examples/research_projects/codeparrot/scripts/arguments.py
Normal file
@ -0,0 +1,175 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
"""
|
||||
Configuration for training model.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="lvwerra/codeparrot",
|
||||
metadata={"help": "Model name or path of model to be trained."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
default="./",
|
||||
metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
|
||||
)
|
||||
dataset_name_train: Optional[str] = field(
|
||||
default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
|
||||
)
|
||||
dataset_name_valid: Optional[str] = field(
|
||||
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
|
||||
)
|
||||
train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."})
|
||||
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
|
||||
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
|
||||
shuffle_buffer: Optional[int] = field(
|
||||
default=1000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
|
||||
num_warmup_steps: Optional[int] = field(
|
||||
default=750, metadata={"help": "Number of warmup steps in the learning rate schedule."}
|
||||
)
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "Number of gradient accumulation steps."}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
|
||||
)
|
||||
max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
|
||||
max_eval_steps: Optional[int] = field(
|
||||
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
|
||||
)
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "Sequence lengths used for training."})
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Training seed."})
|
||||
save_checkpoint_steps: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationArguments:
|
||||
"""
|
||||
Configuration for evaluating model.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="lvwerra/codeparrot",
|
||||
metadata={"help": "Model name or path of model to be evaluated."},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
|
||||
)
|
||||
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size used for evaluation."})
|
||||
max_eval_steps: Optional[int] = field(
|
||||
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
|
||||
)
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "Length of sequences to be evaluated."})
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanEvalArguments:
|
||||
"""
|
||||
Configuration for running evaluation on HumanEval dataset.
|
||||
"""
|
||||
|
||||
model_ckpt: Optional[str] = field(
|
||||
default="lvwerra/codeparrot",
|
||||
metadata={"help": "Model name or path of model to be evaluated."},
|
||||
)
|
||||
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Sample from the language model's output distribution."}
|
||||
)
|
||||
temperature: Optional[float] = field(default=0.2, metadata={"help": "Sampling temperature used for generation."})
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "Maximum number of newly generated tokens."})
|
||||
top_k: Optional[int] = field(default=0, metadata={"help": "Top-k parameter used for generation."})
|
||||
top_p: Optional[float] = field(default=0.95, metadata={"help": "Top-p parameter used for nucleus sampling."})
|
||||
batch_size: Optional[int] = field(default=10, metadata={"help": "Number of generations to run in parallel."})
|
||||
n_samples: Optional[int] = field(
|
||||
default=200, metadata={"help": "Number of completions to generate for each sample."}
|
||||
)
|
||||
seed: Optional[int] = field(default=1, metadata={"help": "Random seed used for evaluation."})
|
||||
output_file: Optional[str] = field(
|
||||
default="eval_results.json", metadata={"help": "Random seed used for evaluation."}
|
||||
)
|
||||
HF_ALLOW_CODE_EVAL: Optional[str] = field(
|
||||
default="0", metadata={"help": "Allow `code_eval` to execute Python code on machine"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessingArguments:
|
||||
"""
|
||||
Configuration for preprocessing data.
|
||||
"""
|
||||
|
||||
num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of CPU cores to use for parallel preprocessing. Default uses the maximum available."
|
||||
},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="codeparrot", metadata={"help": "Folder or name of dataset to process."}
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default="codeparrot-clean", metadata={"help": "Folder to save processed processed dataset."}
|
||||
)
|
||||
samples_per_file: Optional[int] = field(
|
||||
default=100_000, metadata={"help": "Number of files to save per JSON output file."}
|
||||
)
|
||||
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
|
||||
line_max: Optional[float] = field(
|
||||
default=1000, metadata={"help": "Maximum line length in file, otherwise file is filtered."}
|
||||
)
|
||||
line_mean: Optional[float] = field(
|
||||
default=100, metadata={"help": "Maximum mean line length in file, otherwise file is filtered."}
|
||||
)
|
||||
alpha_frac: Optional[float] = field(
|
||||
default=0.25, metadata={"help": "Maximum fraction of non-alphanumeric characters, otherwise file is filtered."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerTrainingArguments:
|
||||
"""
|
||||
Configuration for tokenizer training.
|
||||
"""
|
||||
|
||||
base_tokenizer: Optional[str] = field(
|
||||
default="gpt2",
|
||||
metadata={"help": "Base tokenizer to build new tokenizer from."},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
|
||||
)
|
||||
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
|
||||
vocab_size: Optional[int] = field(default=200000, metadata={"help": "Number of examples to train tokenizer on."})
|
||||
n_examples: Optional[int] = field(
|
||||
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of new tokenizer."})
|
||||
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitializationArguments:
|
||||
"""
|
||||
Configuration for initializing new model.
|
||||
"""
|
||||
|
||||
config_name: Optional[str] = field(
|
||||
default="gpt2-large",
|
||||
metadata={"help": "Configuration to use for model initialization."},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}
|
||||
)
|
||||
model_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of the created model."})
|
||||
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
|
@ -0,0 +1,32 @@
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from transformers import GPT2Tokenizer, HfArgumentParser
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
|
||||
# Iterator for Training
|
||||
def batch_iterator(batch_size=10):
|
||||
for _ in tqdm(range(0, args.n_examples, batch_size)):
|
||||
yield [next(iter_dataset)[args.text_column] for _ in range(batch_size)]
|
||||
|
||||
|
||||
# Configuration
|
||||
parser = HfArgumentParser(TokenizerTrainingArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Base tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.base_tokenizer)
|
||||
base_vocab = list(bytes_to_unicode().values())
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset(args.dataset_name, split="train", streaming=True)
|
||||
iter_dataset = iter(dataset)
|
||||
|
||||
|
||||
# Training and saving
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
batch_iterator(), vocab_size=args.vocab_size, initial_alphabet=base_vocab
|
||||
)
|
||||
new_tokenizer.save_pretrained(args.tokenizer_name, push_to_hub=args.push_to_hub)
|
@ -0,0 +1,240 @@
|
||||
import logging
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import transformers
|
||||
import wandb
|
||||
from accelerate import Accelerator
|
||||
from arguments import TrainingArguments
|
||||
from huggingface_hub import Repository
|
||||
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
num_of_sequences: Number of token sequences to keep in buffer.
|
||||
chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.bos_token_id
|
||||
self.dataset = dataset
|
||||
self.seq_length = seq_length
|
||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||
self.epoch = 0
|
||||
self.infinite = infinite
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.input_characters:
|
||||
break
|
||||
try:
|
||||
buffer.append(next(iterator)["content"])
|
||||
buffer_len += len(buffer[-1])
|
||||
except StopIteration:
|
||||
if self.infinite:
|
||||
iterator = iter(self.dataset)
|
||||
self.epoch += 1
|
||||
logger.info(f"Dataset epoch: {self.epoch}")
|
||||
else:
|
||||
more_examples = False
|
||||
break
|
||||
tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
yield torch.tensor(input_ids)
|
||||
|
||||
|
||||
def setup_logging(args):
|
||||
project_name = args.model_ckpt.split("/")[-1]
|
||||
logger = logging.getLogger(__name__)
|
||||
log_dir = Path(args.save_dir) / "log/"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
filename = f"debug_{accelerator.process_index}.log"
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
|
||||
)
|
||||
if accelerator.is_main_process: # we only want to setup logging once
|
||||
wandb.init(project=project_name, config=args)
|
||||
run_name = wandb.run.name
|
||||
tb_writer = SummaryWriter()
|
||||
tb_writer.add_hparams(vars(args), {"0": 0})
|
||||
logger.setLevel(logging.INFO)
|
||||
datasets.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
tb_writer = None
|
||||
run_name = ""
|
||||
logger.setLevel(logging.ERROR)
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
return logger, tb_writer, run_name
|
||||
|
||||
|
||||
def create_dataloaders(args):
|
||||
ds_kwargs = {"streaming": True}
|
||||
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
||||
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
|
||||
train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
|
||||
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
|
||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
|
||||
params_with_wd, params_without_wd = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if any(nd in n for nd in no_decay):
|
||||
params_without_wd.append(p)
|
||||
else:
|
||||
params_with_wd.append(p)
|
||||
return [
|
||||
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
||||
{"params": params_without_wd, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
|
||||
def log_metrics(step, metrics):
|
||||
logger.info(f"Step {step}: {metrics}")
|
||||
if accelerator.is_main_process:
|
||||
wandb.log(metrics)
|
||||
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(batch, labels=batch)
|
||||
loss = outputs.loss.repeat(args.valid_batch_size)
|
||||
losses.append(accelerator.gather(loss))
|
||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||
break
|
||||
loss = torch.mean(torch.cat(losses))
|
||||
try:
|
||||
perplexity = torch.exp(loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
return loss.item(), perplexity.item()
|
||||
|
||||
|
||||
# Accelerator
|
||||
accelerator = Accelerator()
|
||||
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
||||
|
||||
# Settings
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
args = Namespace(**vars(args), **acc_state)
|
||||
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
||||
set_seed(args.seed)
|
||||
|
||||
# Clone model repository
|
||||
if accelerator.is_main_process:
|
||||
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
|
||||
|
||||
# Logging
|
||||
logger, tb_writer, run_name = setup_logging(args)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Checkout new branch on repo
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.git_checkout(run_name, create_branch_ok=True)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
|
||||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
|
||||
|
||||
# Load dataset and dataloader
|
||||
train_dataloader, eval_dataloader = create_dataloaders(args)
|
||||
|
||||
# Prepare the optimizer and learning rate scheduler
|
||||
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
|
||||
def get_lr():
|
||||
return optimizer.param_groups[0]["lr"]
|
||||
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# Train model
|
||||
model.train()
|
||||
completed_steps = 0
|
||||
for step, batch in enumerate(train_dataloader, start=1):
|
||||
loss = model(batch, labels=batch, use_cache=False).loss
|
||||
log_metrics(
|
||||
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
|
||||
)
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
completed_steps += 1
|
||||
if step % args.save_checkpoint_steps == 0:
|
||||
logger.info("Evaluating and saving model checkpoint")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message=f"step {step}")
|
||||
model.train()
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# Evaluate and save the last checkpoint
|
||||
logger.info("Evaluating and saving model after training")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message="final model")
|
87
examples/research_projects/codeparrot/scripts/human_eval.py
Normal file
87
examples/research_projects/codeparrot/scripts/human_eval.py
Normal file
@ -0,0 +1,87 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
|
||||
from datasets import load_dataset, load_metric
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from arguments import HumanEvalArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed
|
||||
|
||||
|
||||
def first_block(string):
|
||||
"""Split off first block of code by scanning for class, def etc. on newlines."""
|
||||
return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()
|
||||
|
||||
|
||||
def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
|
||||
"""Complete prompt with text generation pipeline and return num_completions."""
|
||||
prompt = pipe.tokenizer.eos_token + prompt
|
||||
code_gens = pipe(prompt, num_return_sequences=num_completions, **gen_kwargs)
|
||||
return [first_block(code_gen["generated_text"][len(prompt) :]) for code_gen in code_gens]
|
||||
|
||||
|
||||
def main():
|
||||
# Setup configuration
|
||||
parser = HfArgumentParser(HumanEvalArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
# enables code execution in code_eval metric
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = args.HF_ALLOW_CODE_EVAL
|
||||
# make sure tokenizer plays nice with multiprocessing
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
if args.num_workers is None:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
|
||||
set_seed(args.seed)
|
||||
|
||||
# Generation settings
|
||||
gen_kwargs = {
|
||||
"do_sample": args.do_sample,
|
||||
"temperature": args.temperature,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
}
|
||||
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
|
||||
|
||||
# Load evaluation dataset and metric
|
||||
human_eval = load_dataset("openai_humaneval")
|
||||
code_eval_metric = load_metric("code_eval")
|
||||
|
||||
# Generate completions for evaluation set
|
||||
n_tasks = 4 # len(human_eval["test"])
|
||||
generations, references = [], []
|
||||
for task in tqdm(range(n_tasks)):
|
||||
task_generations = []
|
||||
prompt = human_eval["test"][task]["prompt"].strip()
|
||||
for batch in range(args.n_samples // args.batch_size):
|
||||
task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
|
||||
generations.append([prompt + gen for gen in task_generations])
|
||||
test_func = human_eval["test"][task]["test"]
|
||||
entry_point = f"check({human_eval['test'][task]['entry_point']})"
|
||||
references.append("\n" + test_func + "\n" + entry_point)
|
||||
|
||||
# Evaluate completions with "code_eval" metric
|
||||
pass_at_k, _ = code_eval_metric.compute(
|
||||
references=references, predictions=generations, num_workers=args.num_workers
|
||||
)
|
||||
print(f"Results: {pass_at_k}")
|
||||
|
||||
# Save results to json file
|
||||
with open(args.output_file, "w") as fp:
|
||||
json.dump(pass_at_k, fp)
|
||||
|
||||
|
||||
# For some reason the folliwng seems to be necessary sometimes for code_eval to work nice with multiprocessing
|
||||
# https://stackoverflow.com/questions/60804599/python-multiprocessing-keeps-spawning-the-whole-script
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,22 @@
|
||||
from arguments import InitializationArguments
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
# Configuration
|
||||
parser = HfArgumentParser(InitializationArguments)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load codeparrot tokenizer trained for Python code tokenization
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
|
||||
|
||||
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
|
||||
config_kwargs = {"vocab_size": len(tokenizer), "scale_attn_by_layer_idx": True, "reorder_and_upcast_attn": True}
|
||||
|
||||
# Load model config (GPT-2 large in this case)
|
||||
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
|
||||
|
||||
# Initialize new model with config
|
||||
model = AutoModelForCausalLM(config)
|
||||
|
||||
# Save model to the hub
|
||||
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
|
122
examples/research_projects/codeparrot/scripts/preprocessing.py
Normal file
122
examples/research_projects/codeparrot/scripts/preprocessing.py
Normal file
@ -0,0 +1,122 @@
|
||||
import gzip
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PreprocessingArguments
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
def get_hash(example):
|
||||
"""Get hash of content field."""
|
||||
return {"hash": hash(example["content"])}
|
||||
|
||||
|
||||
def line_stats(example):
|
||||
"""Calculates mean and max line length of file."""
|
||||
line_lengths = [len(line) for line in example["content"].splitlines()]
|
||||
return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}
|
||||
|
||||
|
||||
def alpha_stats(example):
|
||||
"""Calculates mean and max line length of file."""
|
||||
alpha_frac = np.mean([c.isalnum() for c in example["content"]])
|
||||
return {"alpha_frac": alpha_frac}
|
||||
|
||||
|
||||
def check_uniques(example, uniques):
|
||||
"""Check if current hash is still in set of unique hashes and remove if true."""
|
||||
if example["hash"] in uniques:
|
||||
uniques.remove(example["hash"])
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_autogenerated(example, scan_width=5):
|
||||
"""Check if file is autogenerated by looking for keywords in the first few lines of the file."""
|
||||
keywords = ["auto-generated", "autogenerated", "automatically generated"]
|
||||
lines = example["content"].splitlines()
|
||||
for _, line in zip(range(scan_width), lines):
|
||||
for keyword in keywords:
|
||||
if keyword in line.lower():
|
||||
return {"autogenerated": True}
|
||||
else:
|
||||
return {"autogenerated": False}
|
||||
|
||||
|
||||
def preprocess(example):
|
||||
"""Chain all preprocessing steps into one function to not fill cache."""
|
||||
results = dict()
|
||||
results.update(get_hash(example))
|
||||
results.update(line_stats(example))
|
||||
results.update(alpha_stats(example))
|
||||
results.update(is_autogenerated(example))
|
||||
return results
|
||||
|
||||
|
||||
def filter(example, uniques, args):
|
||||
"""Filter dataset with heuristics."""
|
||||
if not check_uniques(example, uniques):
|
||||
return False
|
||||
elif example["autogenerated"]:
|
||||
return False
|
||||
elif example["line_max"] > args.line_max:
|
||||
return False
|
||||
elif example["line_mean"] > args.line_mean:
|
||||
return False
|
||||
elif example["alpha_frac"] < args.alpha_frac:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def compress_file(file_path):
|
||||
"""Compress a file with g-zip."""
|
||||
with open(file_path, "rb") as f_in:
|
||||
with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
# Settings
|
||||
parser = HfArgumentParser(PreprocessingArguments)
|
||||
args = parser.parse_args()
|
||||
if args.num_workers is None:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
|
||||
# Load dataset
|
||||
t_start = time.time()
|
||||
ds = load_dataset(args.dataset_name, split="train")
|
||||
print(f"Time to load dataset: {time.time()-t_start:.2f}")
|
||||
|
||||
# Run preprocessing
|
||||
t_start = time.time()
|
||||
ds = ds.map(preprocess, num_proc=args.num_workers)
|
||||
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")
|
||||
|
||||
# Deduplicate hashes
|
||||
uniques = set(ds.unique("hash"))
|
||||
frac = len(uniques) / len(ds)
|
||||
print(f"Fraction of duplicates: {1-frac:.2%}")
|
||||
|
||||
# Deduplicate data and apply heuristics
|
||||
t_start = time.time()
|
||||
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
|
||||
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
|
||||
print(f"Size of filtered dataset: {len(ds_filter)}")
|
||||
|
||||
# Save data in batches of samples_per_file
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
t_start = time.time()
|
||||
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
|
||||
file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
|
||||
end_index = min(len(ds_filter), index + args.samples_per_file)
|
||||
ds_filter.select(list(range(index, end_index))).to_json(file_path)
|
||||
compress_file(file_path)
|
||||
print(f"Time to save dataset: {time.time()-t_start:.2f}")
|
@ -0,0 +1,99 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.bos_token_id
|
||||
self.dataset = dataset
|
||||
self.seq_length = seq_length
|
||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.input_characters:
|
||||
break
|
||||
try:
|
||||
buffer.append(next(iterator)["content"])
|
||||
buffer_len += len(buffer[-1])
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
break
|
||||
tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
yield torch.tensor(input_ids)
|
||||
|
||||
|
||||
def create_dataloader(args):
|
||||
ds_kwargs = {"streaming": True}
|
||||
valid_data = load_dataset(args.dataset_name, split="train", **ds_kwargs)
|
||||
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
|
||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
|
||||
return eval_dataloader
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(batch, labels=batch)
|
||||
loss = outputs.loss.repeat(args.batch_size)
|
||||
losses.append(accelerator.gather(loss))
|
||||
|
||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||
break
|
||||
loss = torch.mean(torch.cat(losses))
|
||||
try:
|
||||
perplexity = torch.exp(loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
return loss.item(), perplexity.item()
|
||||
|
||||
|
||||
# Setup Accelerator
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Parse configuration
|
||||
parser = HfArgumentParser(EvaluationArguments)
|
||||
args = parser.parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
# Logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
|
||||
|
||||
# Load dataset and dataloader
|
||||
eval_dataloader = create_dataloader(args)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
|
||||
|
||||
# Evaluate and save the last checkpoint
|
||||
logger.info("Evaluating and saving model after training")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")
|
Loading…
Reference in New Issue
Block a user