transformers/examples/research_projects/codeparrot/scripts/preprocessing.py
Jia LI da2bd2ae96
[CodeParrot] Near-deduplication with jaccard similarity (#17054)
* deduplication draft

* update style

* update style test

* dummy test main

* rename modules

* rename functions

* return extremes in deduplicate_clusters

* update style

* cast str for gzip

* update doc string

* time processing

* use dataset map to compute minhash

* fill value for short token

* remove da map method

* update style

* use share object to multiprocess

* update style

* use f-string and minor fix

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* update style

* use module parameters

* change ds_dedup to ds_filter

* save ds_dedup

* mv test to script tests

* make jaccard threshold a parameter of deduplicate_dataset

* update style

* add doc strings

* update style

* add doc string for DuplicationIndex

* save files into data dir

* update readme

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* make near deduplication optional

* move near deduplication in README

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* use f string

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
2022-06-21 14:23:36 +02:00

212 lines
7.2 KiB
Python

import gzip
import hashlib
import json
import multiprocessing
import os
import shutil
import time
from pathlib import Path
import numpy as np
from datasets import load_dataset
from arguments import PreprocessingArguments
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser
def get_hash(example):
    """Return the MD5 hex digest of the whitespace-stripped `content` field.

    The digest is returned as ``{"hash": <hex string>}`` so it can be merged
    into the example by `datasets.Dataset.map`.
    """
    payload = example["content"].strip().encode("utf-8")
    digest = hashlib.md5(payload)
    return {"hash": digest.hexdigest()}
def line_stats(example):
    """Calculates mean and max line length of file.

    Returns ``{"line_mean": ..., "line_max": ...}``. A file without any line
    (empty `content`) reports 0 for both values instead of raising, because
    `max()` of an empty sequence is a ValueError and `np.mean([])` is nan.
    """
    line_lengths = [len(line) for line in example["content"].splitlines()]
    if not line_lengths:
        return {"line_mean": 0.0, "line_max": 0}
    return {"line_mean": np.mean(line_lengths), "line_max": max(line_lengths)}
def alpha_stats(example):
    """Calculates the fraction of alphanumeric characters in the file.

    Returns ``{"alpha_frac": <value in [0, 1]>}``. Empty `content` yields 0.0
    rather than the nan (plus RuntimeWarning) that `np.mean([])` would give.
    """
    content = example["content"]
    if not content:
        return {"alpha_frac": 0.0}
    alpha_frac = np.mean([c.isalnum() for c in content])
    return {"alpha_frac": alpha_frac}
def check_uniques(example, uniques):
    """Report whether the example's hash is still unclaimed, claiming it if so.

    Returns True the first time a given hash is seen (and removes it from
    `uniques`, mutating the caller's collection); returns False otherwise.
    """
    digest = example["hash"]
    if digest not in uniques:
        return False
    # First occurrence: consume the hash so later duplicates are rejected.
    uniques.remove(digest)
    return True
def is_autogenerated(example, scan_width=5):
    """Flag files whose first `scan_width` lines mention being auto-generated.

    The match is case-insensitive. Returns ``{"autogenerated": bool}``.
    """
    markers = ("auto-generated", "autogenerated", "automatically generated")
    # zip() with range() caps the scan at `scan_width` lines (or fewer for short files).
    head = (text.lower() for _, text in zip(range(scan_width), example["content"].splitlines()))
    flagged = any(marker in text for text in head for marker in markers)
    return {"autogenerated": flagged}
def is_config_or_test(example, scan_width=5, coeff=0.05):
    """Check if file is a configuration file or a unit test by:
    1- looking for keywords in the first `scan_width` lines of the file.
    2- counting the number of occurrences of the words 'config' and 'test' and
       comparing each count to `coeff` times the number of newlines in the file.

    Both checks are case-insensitive. Returns ``{"config_or_test": bool}``.
    """
    content = example["content"]
    lowered = [text.lower() for text in content.splitlines()]
    header_markers = ("unit tests", "test file", "configuration file")

    # Heuristic 1: an explicit marker near the top of the file.
    for _, text in zip(range(scan_width), lowered):
        if any(marker in text for marker in header_markers):
            return {"config_or_test": True}

    # Heuristic 2: 'config' / 'test' occur more often than the size-scaled limit.
    limit = int(coeff * content.count("\n"))
    n_config = n_test = 0
    for text in lowered:
        n_config += text.count("config")
        n_test += text.count("test")
        if max(n_config, n_test) > limit:
            return {"config_or_test": True}
    return {"config_or_test": False}
def has_no_keywords(example):
    """Check if a python file has none of the keywords for: function, class, for loop, while loop.

    Each line is lower-cased before matching, and every keyword is matched
    with its trailing space. Returns ``{"has_no_keywords": bool}``.
    """
    tokens = ("def ", "class ", "for ", "while ")
    found = any(token in text.lower() for text in example["content"].splitlines() for token in tokens)
    return {"has_no_keywords": not found}
def has_few_assignments(example, minimum=4):
    """Check if file uses symbol '=' at most `minimum` times.

    Every '=' character counts (so '==' contributes two). Returns
    ``{"has_few_assignments": True}`` when the total does not exceed `minimum`.
    """
    # '=' never spans a line break, so counting per line and summing is exact.
    total = sum(text.count("=") for text in example["content"].splitlines())
    return {"has_few_assignments": not total > minimum}
def char_token_ratio(example):
    """Compute character/token ratio of the file with tokenizer.

    Uses the module-level `tokenizer`. Returns ``{"ratio": chars / tokens}``.
    When the tokenizer produces no tokens (e.g. empty `content`) the ratio is
    reported as 0.0 instead of raising ZeroDivisionError.
    """
    content = example["content"]
    input_ids = tokenizer(content, truncation=False)["input_ids"]
    if len(input_ids) == 0:
        return {"ratio": 0.0}
    ratio = len(content) / len(input_ids)
    return {"ratio": ratio}
def preprocess(example):
    """Chain all preprocessing steps into one function to not fill cache.

    Runs every per-example annotator in a fixed order and merges their
    single-key dicts into one result dict.
    """
    steps = (
        get_hash,
        line_stats,
        alpha_stats,
        char_token_ratio,
        is_autogenerated,
        is_config_or_test,
        has_no_keywords,
        has_few_assignments,
    )
    results = {}
    for step in steps:
        results.update(step(example))
    return results
def filter(example, uniques, args):
    """Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability.

    Returns True to keep the example and False to drop it. The checks run in
    a fixed order and stop at the first one that rejects the example.
    """
    # Exact-duplicate check first; it also consumes the hash from `uniques`.
    if not check_uniques(example, uniques):
        return False
    if example["autogenerated"]:
        return False
    if example["line_max"] > args.line_max:
        return False
    if example["line_mean"] > args.line_mean:
        return False
    if example["alpha_frac"] < args.alpha_frac:
        return False
    if example["ratio"] < args.min_token_ratio:
        return False
    # Probabilistic rejections: a random draw happens only when the flag is set.
    for flag in ("config_or_test", "has_no_keywords"):
        if example[flag] and np.random.rand() <= args.filter_proba:
            return False
    return not example["has_few_assignments"]
def compress_file(file_path):
    """Compress a file with g-zip.

    Writes ``<file_path>.gz`` next to the source file (compression level 6)
    and then deletes the uncompressed original.
    """
    archive_path = str(file_path) + ".gz"
    with open(file_path, "rb") as raw, gzip.open(archive_path, "wb", compresslevel=6) as packed:
        shutil.copyfileobj(raw, packed)
    # Only reached once both handles are closed and the archive is complete.
    os.unlink(file_path)
# Settings
parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args()
if args.num_workers is None:
    # Default to one worker per available CPU core.
    args.num_workers = multiprocessing.cpu_count()
# Kept at module level: char_token_ratio() reads `tokenizer` as a global.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

# Load dataset
t_start = time.time()
ds = load_dataset(args.dataset_name, split="train")
print(f"Time to load dataset: {time.time()-t_start:.2f}")

# Run preprocessing
t_start = time.time()
ds = ds.map(preprocess, num_proc=args.num_workers)
print(f"Time to preprocess dataset: {time.time()-t_start:.2f}")

# Deduplicate hashes
uniques = set(ds.unique("hash"))
frac = len(uniques) / len(ds)
print(f"Fraction of duplicates: {1-frac:.2%}")

# Deduplicate data and apply heuristics
# NOTE(review): check_uniques() removes hashes from `uniques` as examples are
# visited, so this step presumably has to stay single-process (no num_proc).
t_start = time.time()
ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")

# Deduplicate with minhash and jaccard similarity
if args.near_deduplication:
    t_start = time.time()
    ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
    print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
    print(f"Size of deduplicate dataset: {len(ds_filter)}")

# Save data in batches of samples_per_file
output_dir = Path(args.output_dir)
# parents=True so a nested output path whose parent does not exist yet is created too.
output_dir.mkdir(parents=True, exist_ok=True)

# Save duplicate_clusters in the output_dir as artifacts.
# Not sure this is the right place to save it.
# `duplicate_clusters` only exists when near-deduplication ran, hence the same guard.
if args.near_deduplication:
    with open(output_dir / "duplicate_clusters.json", "w") as f:
        json.dump(duplicate_clusters, f)

data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)

t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
    # Files are numbered from 1 and zero-padded to 12 digits.
    file_path = str(data_dir / f"file-{file_number+1:012}.json")
    # The last shard may hold fewer than samples_per_file examples.
    end_index = min(len(ds_filter), index + args.samples_per_file)
    ds_filter.select(list(range(index, end_index))).to_json(file_path)
    # Replaces the .json shard with a .json.gz one.
    compress_file(file_path)
print(f"Time to save dataset: {time.time()-t_start:.2f}")