import gzip
import multiprocessing
import os
import shutil
import time

import numpy as np
from datasets import load_dataset

from arguments import PreprocessingArguments
from transformers import HfArgumentParser
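
# Preprocessing pipeline (as implemented below): compute a content hash and
# simple statistics for every example, drop exact duplicates and files that
# look auto-generated or fail the length/alphanumeric heuristics, then write
# the result to gzip-compressed JSON shards of `samples_per_file` examples.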


def get_hash(example):
    """Get hash of content field."""
    return {"hash": hash(example["content"])}
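
# Note: Python's built-in hash() is salted per interpreter process for str
# objects, so these hashes are only comparable within a single run (and across
# worker processes only when they share the same hash seed, e.g. under fork or
# with PYTHONHASHSEED fixed). They are used here purely for in-run deduplication.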


def line_stats(example):
    """Calculates mean and max line length of file."""
    line_lengths = [len(line) for line in example["content"].splitlines()]
    return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}


def alpha_stats(example):
    """Calculates the fraction of alphanumeric characters in the file."""
    alpha_frac = np.mean([c.isalnum() for c in example["content"]])
    return {"alpha_frac": alpha_frac}
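
# Illustrative values: for content "x = 1\nprint(x)", line_stats() returns
# {"line_mean": 6.5, "line_max": 8} (line lengths 5 and 8), and alpha_stats()
# returns {"alpha_frac": 8 / 14}, since 8 of the 14 characters are alphanumeric.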


def check_uniques(example, uniques):
    """Check if current hash is still in set of unique hashes and remove if true."""
    if example["hash"] in uniques:
        uniques.remove(example["hash"])
        return True
    else:
        return False
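
# Keep-first deduplication: the first example with a given hash still finds it
# in `uniques`, removes it, and is kept; every later example with the same hash
# finds nothing and is dropped by the filter below.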


def is_autogenerated(example, scan_width=5):
    """Check if file is autogenerated by looking for keywords in the first few lines of the file."""
    keywords = ["auto-generated", "autogenerated", "automatically generated"]
    lines = example["content"].splitlines()
    for _, line in zip(range(scan_width), lines):
        for keyword in keywords:
            if keyword in line.lower():
                return {"autogenerated": True}
    return {"autogenerated": False}


def preprocess(example):
    """Chain all preprocessing steps into one function to not fill cache."""
    results = dict()
    results.update(get_hash(example))
    results.update(line_stats(example))
    results.update(alpha_stats(example))
    results.update(is_autogenerated(example))
    return results
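
# Merging all per-example statistics into one dict lets a single .map() pass
# add every new column at once instead of caching several intermediate datasets.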


def filter(example, uniques, args):
    """Filter dataset with heuristics."""
    if not check_uniques(example, uniques):
        return False
    elif example["autogenerated"]:
        return False
    elif example["line_max"] > args.line_max:
        return False
    elif example["line_mean"] > args.line_mean:
        return False
    elif example["alpha_frac"] < args.alpha_frac:
        return False
    else:
        return True


def compress_file(file_path):
    """Compress a file with gzip and remove the uncompressed original."""
    with open(file_path, "rb") as f_in:
        with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.unlink(file_path)


# Settings
parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args()
if args.num_workers is None:
    args.num_workers = multiprocessing.cpu_count()
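
# Example invocation (hypothetical script name and values; the flag names assume
# the matching fields on PreprocessingArguments, which HfArgumentParser exposes
# as CLI flags):
#   python preprocessing.py --dataset_name <hub-dataset-id> --output_dir ./data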

# Load dataset
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Time to load dataset: {time.time()-t_start:.2f}")

# Run preprocessing
t_start = time.time()
ds = ds.map(preprocess, num_proc=args.num_workers)
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")

# Deduplicate hashes
uniques = set(ds.unique("hash"))
frac = len(uniques) / len(ds)
print(f"Fraction of duplicates: {1-frac:.2%}")
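
# `ds.unique("hash")` returns the distinct hash values, so len(uniques) / len(ds)
# is the unique fraction and the printed value is the share of duplicate examples.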

# Deduplicate data and apply heuristics
t_start = time.time()
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")
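
# No num_proc is passed here: the filter mutates the shared `uniques` set, which
# only works reliably in a single process (each worker would otherwise operate
# on its own copy of the set and let duplicates through).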

# Save data in batches of samples_per_file
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
    file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
    end_index = min(len(ds_filter), index + args.samples_per_file)
    ds_filter.select(list(range(index, end_index))).to_json(file_path)
    compress_file(file_path)
print(f"Time to save dataset: {time.time()-t_start:.2f}")