[CodeParrot] Near-deduplication with jaccard similarity (#17054)
* deduplication draft
* update style
* update style test
* dummy test main
* rename modules
* rename functions
* return extremes in deduplicate_clusters
* update style
* cast str for gzip
* update doc string
* time processing
* use dataset map to compute minhash
* fill value for short token
* remove da map method
* update style
* use share object to multiprocess
* update style
* use f-string and minor fix
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
  Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
* update style
* use module parameters
* change ds_dedup to ds_filter
* save ds_dedup
* mv test to script tests
* make jaccard threshold a parameter of deduplicate_dataset
* update style
* add doc strings
* update style
* add doc string for DuplicationIndex
* save files into data dir
* update readme
* Update examples/research_projects/codeparrot/README.md
  Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
* make near deduplication optional
* move near deduplication in README
* Update examples/research_projects/codeparrot/README.md
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* use f string

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
parent eb16be415a
commit da2bd2ae96
@@ -40,6 +40,7 @@ The source of the dataset is the GitHub dump available on Google's [BigQuery](ht
 The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:
 
 - exact deduplication using each file's hash
+- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. These clusters are then reduced to unique files based on the exact Jaccard similarity. See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description.
 - filtering files with max line length > 1000
 - filtering files with mean line length > 100
 - fraction of alphanumeric characters < 0.25
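The near-deduplication heuristic above boils down to comparing sets of alphanumeric tokens. Below is a minimal sketch (not part of this commit) of that exact token-level Jaccard similarity, using the same `[^A-Za-z_0-9]` split as `minhash_deduplication.py`; the two snippets are made-up examples.

```python
# Minimal sketch of the token-level Jaccard similarity used for near-deduplication.
# The regex mirrors NON_ALPHA from minhash_deduplication.py; the snippets are invented.
import re

NON_ALPHA = re.compile("[^A-Za-z_0-9]")


def tokens(code: str) -> set:
    return {t for t in NON_ALPHA.split(code) if t.strip()}


snippet_a = "def add(a, b):\n    return a + b\n"
snippet_b = "def add(x, y):\n    return x + y\n"

ta, tb = tokens(snippet_a), tokens(snippet_b)
print(len(ta & tb) / len(ta | tb))  # files with similarity >= 0.85 end up in the same cluster
```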
@@ -4,4 +4,6 @@ wandb==0.12.0
 tensorboard==2.6.0
 torch==1.11.0
 huggingface-hub==0.1.0
 git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
+datasketch==1.5.7
+dpu_utils
@@ -157,6 +157,12 @@ class PreprocessingArguments:
         default="lvwerra/codeparrot",
         metadata={"help": "Name or path to the tokenizer."},
     )
+    near_deduplication: Optional[bool] = field(
+        default=False, metadata={"help": "If True, near-duplicate samples are removed."}
+    )
+    jaccard_threshold: Optional[float] = field(
+        default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
+    )
 
 
 @dataclass
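For context (not part of the diff): these fields are parsed with `HfArgumentParser`, the same way `preprocessing.py` already consumes `PreprocessingArguments`, so near-deduplication can be switched on from the command line. A rough sketch, assuming `arguments.py` is importable and boolean flags are passed with an explicit value:

```python
# Rough sketch: parsing the new preprocessing flags via HfArgumentParser.
from arguments import PreprocessingArguments
from transformers import HfArgumentParser

parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args_into_dataclasses()[0]
# e.g. invoked as: python preprocessing.py --near_deduplication True --jaccard_threshold 0.9
print(args.near_deduplication, args.jaccard_threshold)
```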
@@ -0,0 +1,270 @@
import json
import multiprocessing as mp
import re
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type

from datasets import Dataset
from tqdm import tqdm

from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator


NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256


def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
    """Compute the MinHash of a code snippet."""
    if len(tokens) < MIN_NUM_TOKENS:
        return None
    min_hash = MinHash(num_perm=NUM_PERM)
    for token in set(tokens):
        min_hash.update(token.encode())
    return min_hash


def get_tokens(code: str) -> Set[str]:
    """Tokenize a code snippet."""
    return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0])


class DuplicationIndex:
    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.85,
    ):
        self._duplication_jaccard_threshold = duplication_jaccard_threshold
        self._num_perm = NUM_PERM
        self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)

        self._duplicate_clusters = defaultdict(set)

    def add(self, code_key: Tuple, min_hash: MinHash) -> None:
        """Add a key to _index (MinHashLSH)
        the min_hash is used to query closest matches based on the jaccard_threshold.
        The new key is either added to an existing cluster of one close match,
        or a new cluster is created. The clusters created in this way depend on the order of add.

        Args:
            code_key (Tuple of (index, repo_name, path)):
                Theoretically any hashable key. Here we use a tuple to retrieve the information later.
            min_hash: MinHash of the code_key.
        """
        close_duplicates = self._index.query(min_hash)
        if code_key in self._index.keys:
            print(f"Duplicate key {code_key}")
            return

        self._index.insert(code_key, min_hash)
        if len(close_duplicates) > 0:

            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)
                    break
            else:
                self._duplicate_clusters[close_duplicates[0]].add(code_key)

    def get_duplicate_clusters(self) -> List[List[Dict]]:
        """Export the duplicate clusters.
        For each cluster, the first element is the base element of the cluster.
        The base element has an estimated Jaccard similarity higher than the threshold with all the other elements.

        Returns:
            duplicate_clusters (List[List[Dict]]):
                List of duplicate clusters.
        """
        duplicate_clusters = []
        for base, duplicates in self._duplicate_clusters.items():
            cluster = [base] + list(duplicates)
            # reformat the cluster to be a list of dict
            cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
            duplicate_clusters.append(cluster)
        return duplicate_clusters

    def save(self, filepath) -> None:
        duplicate_clusters = self.get_duplicate_clusters()
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)


def _compute_min_hash(element):
    index, data = element
    min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
    if min_hash is not None:
        return (index, data["repo_name"], data["path"]), min_hash


def minhash_iter(dataset_iterator: Type[Dataset]):
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
            chunksize=100,
        ):
            if data is not None:
                yield data


def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
    """Find duplicate clusters in the dataset in two steps:
    1. Compute MinHash for each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
       This step is computed using an asynchronous multiprocessing pool, minhash_iter
    2. Find duplicate clusters. The computed MinHash is added sequentially to the DuplicationIndex.
       This step cannot be parallelized. So using asynchronous threads in the previous step helps to speed up the process.
    """
    di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

    for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
        di.add(filename, min_hash)

    # Returns a List[Cluster] where Cluster is a List[Dict] with the file information.
    return di.get_duplicate_clusters()


def jaccard_similarity(code1: str, code2: str) -> float:
    """Compute the Jaccard similarity of two code snippets."""
    tokens1 = get_tokens(code1)
    tokens2 = get_tokens(code2)
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)


_shared_dataset = None


def _find_cluster_extremes_shared(cluster, jaccard_threshold):
    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.
    Two codes are similar if their Jaccard similarity is above the threshold.

    Args:
        cluster (List[dict]):
            cluster is a list of dict, each dict contains the following keys:
                - base_index
                - repo_name
                - path
            This is a typical output of DuplicationIndex.get_duplicate_clusters()
        jaccard_threshold (float):
            threshold for Jaccard similarity.
            Two codes are similar if their Jaccard similarity is above the threshold.

    Returns:
        extremes (List[dict]):
            A reduced representation of the cluster. The field copies is added to each dict.
            The copies field indicates the number of similar codes in the cluster for an extreme.
    """
    extremes = []
    for element1 in cluster:
        code1 = _shared_dataset[element1["base_index"]]["content"]
        for element2 in extremes:
            code2 = _shared_dataset[element2["base_index"]]["content"]
            if jaccard_similarity(code1, code2) >= jaccard_threshold:
                element2["copies"] += 1
                break
        else:
            element1["copies"] = 1
            extremes.append(element1)
    return extremes


def find_extremes(cluster_list, dataset, jaccard_threshold):
    """Call the _find_cluster_extremes_shared function in a parallel fashion.

    Args:
        cluster_list (List[List[Dict]]):
            each cluster is a list of dicts with the key base_index,
            referring to the index of the base code in the dataset.
        dataset (Type[Dataset]):
            dataset is used to access the content of the code snippets,
            using the base_index from the cluster_list.
            dataset is shared between all the processes using a global variable (any other way to share the dataset?),
            otherwise multiprocessing is not sped up.
        jaccard_threshold (float):
            the threshold for the Jaccard similarity. The default value is 0.85

    Returns:
        extremes_list (List[Dict]):
            Each cluster is reduced to extremes.
            See _find_cluster_extremes_shared for the definition of extremes.
    """
    global _shared_dataset
    _shared_dataset = dataset
    extremes_list = []
    f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in tqdm(
            pool.imap_unordered(
                f,
                cluster_list,
            ),
            total=len(cluster_list),
        ):
            extremes_list.append(extremes)
    return extremes_list


def deduplicate_dataset(
    dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
    """Deduplicate the dataset using MinHash and Jaccard similarity.
    This function first generates duplicate clusters, then each cluster
    is reduced to the extremes that are similar to the other elements in the cluster.
    Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 default).

    Args:
        dataset (Type[Dataset]):
            The dataset to deduplicate.
        jaccard_threshold (float, default=0.85):
            Jaccard threshold to determine if two codes are similar

    Returns:
        ds_dedup (Type[Dataset]):
            The deduplicated dataset.
        duplicate_clusters (List[List[Dict]]):
            The list of duplicate clusters.
            Each cluster is a list of dicts with the following keys:
            - base_index : int
                The index of the code in the original dataset.
            - repo_name : str
            - path : str
            - copies : int
                The number of copies of the code in the cluster. (find_cluster_extremes)
            - is_extreme : bool
                Whether the code is an extreme in the cluster.
            All the codes in the cluster are removed from the dataset except the extremes.

    Example:
        >>> from datasets import load_dataset
        >>> from minhash_deduplication import deduplicate_dataset
        >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
        >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    """
    duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
    duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
    extreme_dict = {}
    extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
    for extremes in extremes_clusters:
        for element in extremes:
            extreme_dict[element["base_index"]] = element
    remove_indices = duplicate_indices - set(extreme_dict.keys())
    ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

    # update duplicate_clusters
    for cluster in duplicate_clusters:
        for element in cluster:
            element["is_extreme"] = element["base_index"] in extreme_dict
            if element["is_extreme"]:
                element["copies"] = extreme_dict[element["base_index"]]["copies"]

    print(f"Original dataset size: {len(dataset)}")
    print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
    print(f"Files in duplicate cluster: {len(duplicate_indices)}")
    print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
    print(f"Filtered dataset size: {len(ds_filter)}")

    return ds_filter, duplicate_clusters
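To make the `DuplicationIndex` logic above concrete, here is a small standalone sketch (not part of the module) of the two `datasketch` primitives it builds on: `MinHash` signatures approximate the token-set Jaccard similarity, and `MinHashLSH` retrieves previously inserted keys whose estimated similarity clears the threshold. The tokens and keys below are made up.

```python
# Sketch of the datasketch primitives used by DuplicationIndex (toy inputs).
from datasketch import MinHash, MinHashLSH


def minhash(tokens, num_perm=256):
    m = MinHash(num_perm=num_perm)
    for t in set(tokens):
        m.update(t.encode())
    return m


m1 = minhash("def add ( a , b ) : return a + b".split())
m2 = minhash("def add ( x , y ) : return x + y".split())

print(m1.jaccard(m2))  # estimated Jaccard similarity of the two token sets

lsh = MinHashLSH(threshold=0.85, num_perm=256)
lsh.insert("file-1", m1)  # analogous to DuplicationIndex.add
print(lsh.query(m2))      # ["file-1"] only if the estimated similarity clears the threshold
```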
@@ -1,14 +1,17 @@
 import gzip
 import hashlib
+import json
 import multiprocessing
 import os
 import shutil
 import time
+from pathlib import Path
 
 import numpy as np
 from datasets import load_dataset
 
 from arguments import PreprocessingArguments
+from minhash_deduplication import deduplicate_dataset
 from transformers import AutoTokenizer, HfArgumentParser
 
 
@@ -146,7 +149,7 @@ def filter(example, uniques, args):
 def compress_file(file_path):
     """Compress a file with g-zip."""
     with open(file_path, "rb") as f_in:
-        with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
+        with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
             shutil.copyfileobj(f_in, f_out)
     os.unlink(file_path)
 
@@ -179,12 +182,29 @@ ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
 print(f"Time to filter dataset: {time.time()-t_start:.2f}")
 print(f"Size of filtered dataset: {len(ds_filter)}")
 
+# Deduplicate with minhash and jaccard similarity
+if args.near_deduplication:
+    t_start = time.time()
+    ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
+    print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
+    print(f"Size of deduplicated dataset: {len(ds_filter)}")
+
 # Save data in batches of samples_per_file
-if not os.path.exists(args.output_dir):
-    os.makedirs(args.output_dir)
+output_dir = Path(args.output_dir)
+output_dir.mkdir(exist_ok=True)
+
+# save duplicate_clusters in the output_dir as artifacts
+# not sure it is the right place to save it
+if args.near_deduplication:
+    with open(output_dir / "duplicate_clusters.json", "w") as f:
+        json.dump(duplicate_clusters, f)
+
+data_dir = output_dir / "data"
+data_dir.mkdir(exist_ok=True)
+
 t_start = time.time()
 for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
-    file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
+    file_path = str(data_dir / f"file-{file_number+1:012}.json")
     end_index = min(len(ds_filter), index + args.samples_per_file)
     ds_filter.select(list(range(index, end_index))).to_json(file_path)
     compress_file(file_path)
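The `deduplicate_dataset` call added above ultimately relies on the greedy "extremes" reduction implemented in `_find_cluster_extremes_shared`. Below is a simplified restatement (illustration only, operating directly on strings rather than the shared dataset), fed with a contrived cluster; `reduce_cluster` is a made-up helper, not part of the commit.

```python
# Simplified restatement of the greedy "extremes" reduction (toy, contrived cluster).
from minhash_deduplication import jaccard_similarity


def reduce_cluster(codes, threshold=0.85):
    # keep a code only if it is not similar to an already kept "extreme";
    # otherwise credit the matching extreme with one more copy
    extremes = []  # list of [code, copies]
    for code in codes:
        for entry in extremes:
            if jaccard_similarity(code, entry[0]) >= threshold:
                entry[1] += 1
                break
        else:
            extremes.append([code, 1])
    return extremes


print(reduce_cluster(["a " * 20, "a " * 30, "b " * 12]))
# two extremes survive: the "a" code with copies=2 and the "b" code with copies=1
```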
@@ -0,0 +1,30 @@
from unittest import TestCase

from datasets import Dataset

from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


def get_dataset():
    data_dict = {
        "repo_name": ["test_repo1", "test_repo2", "test_repo3"],
        "path": ["test_1.py", "test_2.py", "unit_test.py"],
        "content": ["a " * 20, "a " * 30, "b " * 7],
    }
    dataset = Dataset.from_dict(data_dict)
    return dataset


class MakeDuplicateClustersTest(TestCase):
    def test_make_duplicate_clusters(self):
        ds = get_dataset()
        duplicate_clusters = make_duplicate_clusters(ds, 0.85)
        self.assertEqual(len(duplicate_clusters[0]), 2)

    def test_deduplicate_dataset(self):
        ds = get_dataset()
        ds_filter, duplicate_clusters = deduplicate_dataset(ds)
        self.assertEqual(len(ds_filter), 2)
        print(duplicate_clusters)
        self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
        self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)
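For reference (not part of the test), the expected values follow directly from the toy data, as the sketch below shows assuming `minhash_deduplication.py` is importable: the first two files tokenize to the same set, so they form one cluster with one extreme and two copies, while the third file has too few tokens to be indexed at all.

```python
# Why the test expects a single 2-element cluster (illustration only).
from minhash_deduplication import MIN_NUM_TOKENS, NON_ALPHA, get_tokens

print(get_tokens("a " * 20) == get_tokens("a " * 30))  # True: both reduce to {'a'}, exact Jaccard 1.0
n_tokens = len([t for t in NON_ALPHA.split("b " * 7) if t.strip()])
print(n_tokens, MIN_NUM_TOKENS)  # 7 < 10, so get_min_hash returns None and the third file is never indexed
```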