Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[CodeParrot] Near-deduplication with jaccard similarity (#17054)
* deduplication draft
* update style
* update style test
* dummy test main
* rename modules
* rename functions
* return extremes in deduplicate_clusters
* update style
* cast str for gzip
* update doc string
* time processing
* use dataset map to compute minhash
* fill value for short token
* remove da map method
* update style
* use share object to multiprocess
* update style
* use f-string and minor fix

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* update style
* use module parameters
* change ds_dedup to ds_filter
* save ds_dedup
* mv test to script tests
* make jaccard threshold a parameter of deduplicate_dataset
* update style
* add doc strings
* update style
* add doc string for DuplicationIndex
* save files into data dir
* update readme
* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* make near deduplication optional
* move near deduplication in README
* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* use f string

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
Parent: eb16be415a
Commit: da2bd2ae96
@@ -40,6 +40,7 @@ The source of the dataset is the GitHub dump available on Google's [BigQuery](ht
The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:

- exact deduplication using each file's hash
- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. These clusters are then reduced to unique files based on the exact Jaccard similarity. See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description.
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- filtering files with a fraction of alphanumeric characters < 0.25
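To make the near-deduplication bullet above concrete, here is a toy sketch, not part of the commit, of how `datasketch`'s MinHash signatures and LSH index flag near-duplicate files at a 0.85 Jaccard threshold. The file names and the `minhash_of` helper are invented; the real pipeline additionally skips files with fewer than `MIN_NUM_TOKENS` tokens.

```python
import re

from datasketch import MinHash, MinHashLSH

NON_ALPHA = re.compile("[^A-Za-z_0-9]")


def minhash_of(code: str, num_perm: int = 256) -> MinHash:
    """Build a MinHash signature from the set of alphanumeric tokens in a snippet."""
    mh = MinHash(num_perm=num_perm)
    for token in {t for t in NON_ALPHA.split(code) if t.strip()}:
        mh.update(token.encode())
    return mh


# Toy corpus: a.py and b.py contain exactly the same tokens, only the layout differs.
snippets = {
    "a.py": "def add(a, b):\n    return a + b\n",
    "b.py": "def add( a, b ):  return a+b\n",
    "c.py": "print('hello world')\n",
}

lsh = MinHashLSH(threshold=0.85, num_perm=256)
for name, code in snippets.items():
    signature = minhash_of(code)
    # Query returns the already-indexed files whose estimated Jaccard similarity clears the threshold.
    print(name, "near-duplicates already seen:", lsh.query(signature))
    lsh.insert(name, signature)
```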
@@ -4,4 +4,6 @@ wandb==0.12.0
tensorboard==2.6.0
torch==1.11.0
huggingface-hub==0.1.0
git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
datasketch==1.5.7
dpu_utils
@@ -157,6 +157,12 @@ class PreprocessingArguments:
        default="lvwerra/codeparrot",
        metadata={"help": "Name or path to the tokenizer."},
    )
    near_deduplication: Optional[bool] = field(
        default=False, metadata={"help": "If True, near-duplicate samples are removed."}
    )
    jaccard_threshold: Optional[float] = field(
        default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
    )


@dataclass
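For context on how these two new fields are consumed, here is a minimal sketch of parsing them with `HfArgumentParser`, which the preprocessing script already imports (see the import hunk further below). Only the new fields are reproduced; the parsing snippet is illustrative, not the commit's code.

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class PreprocessingArguments:
    # Only the two fields added by this commit are reproduced here.
    near_deduplication: Optional[bool] = field(
        default=False, metadata={"help": "If True, near-duplicate samples are removed."}
    )
    jaccard_threshold: Optional[float] = field(
        default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
    )


parser = HfArgumentParser(PreprocessingArguments)
args = parser.parse_args_into_dataclasses()[0]
if args.near_deduplication:
    print(f"Near deduplication enabled with Jaccard threshold {args.jaccard_threshold}")
```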
@@ -0,0 +1,270 @@
import json
import multiprocessing as mp
import re
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type

from datasets import Dataset
from tqdm import tqdm

from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator


NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256


def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
    """Compute the MinHash of a code snippet."""
    if len(tokens) < MIN_NUM_TOKENS:
        return None
    min_hash = MinHash(num_perm=NUM_PERM)
    for token in set(tokens):
        min_hash.update(token.encode())
    return min_hash


def get_tokens(code: str) -> Set[str]:
    """Tokenize a code snippet."""
    return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0])


class DuplicationIndex:
    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.85,
    ):
        self._duplication_jaccard_threshold = duplication_jaccard_threshold
        self._num_perm = NUM_PERM
        self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)

        self._duplicate_clusters = defaultdict(set)

    def add(self, code_key: Tuple, min_hash: MinHash) -> None:
        """Add a key to _index (MinHashLSH).
        The min_hash is used to query the closest matches based on the jaccard_threshold.
        The new key is either added to an existing cluster of one close match,
        or a new cluster is created. The clusters created in this way depend on the order of add.

        Args:
            code_key (Tuple of (index, repo_name, path)):
                Theoretically any hashable key. Here we use a tuple to retrieve the information later.
            min_hash: MinHash of the code_key.
        """
        close_duplicates = self._index.query(min_hash)
        if code_key in self._index.keys:
            print(f"Duplicate key {code_key}")
            return

        self._index.insert(code_key, min_hash)
        if len(close_duplicates) > 0:

            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)
                    break
            else:
                self._duplicate_clusters[close_duplicates[0]].add(code_key)

    def get_duplicate_clusters(self) -> List[List[Dict]]:
        """Export the duplicate clusters.
        For each cluster, the first element is the base element of the cluster.
        The base element has an estimated Jaccard similarity higher than the threshold with all the other elements.

        Returns:
            duplicate_clusters (List[List[Dict]]):
                List of duplicate clusters.
        """
        duplicate_clusters = []
        for base, duplicates in self._duplicate_clusters.items():
            cluster = [base] + list(duplicates)
            # reformat the cluster to be a list of dicts
            cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
            duplicate_clusters.append(cluster)
        return duplicate_clusters

    def save(self, filepath) -> None:
        duplicate_clusters = self.get_duplicate_clusters()
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)


def _compute_min_hash(element):
    index, data = element
    min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
    if min_hash is not None:
        return (index, data["repo_name"], data["path"]), min_hash


def minhash_iter(dataset_iterator: Type[Dataset]):
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
            chunksize=100,
        ):
            if data is not None:
                yield data


def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
    """Find duplicate clusters in the dataset in two steps:
    1. Compute MinHash for each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
       This step is computed using an asynchronous multiprocessing pool (see minhash_iter).
    2. Find duplicate clusters. The computed MinHashes are added sequentially to the DuplicationIndex.
       This step cannot be parallelized, so using asynchronous threads in the previous step helps to speed up the process.
    """
    di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

    for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
        di.add(filename, min_hash)

    # Returns a list of duplicate clusters; see DuplicationIndex.get_duplicate_clusters for the format.
    return di.get_duplicate_clusters()


def jaccard_similarity(code1: str, code2: str) -> float:
    """Compute the Jaccard similarity of two code snippets."""
    tokens1 = get_tokens(code1)
    tokens2 = get_tokens(code2)
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)


_shared_dataset = None


def _find_cluster_extremes_shared(cluster, jaccard_threshold):
    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.
    Two codes are similar if their Jaccard similarity is above the threshold.

    Args:
        cluster (List[dict]):
            cluster is a list of dicts, each dict containing the following keys:
                - base_index
                - repo_name
                - path
            This is a typical output of DuplicationIndex.get_duplicate_clusters().
        jaccard_threshold (float):
            threshold for Jaccard similarity.
            Two codes are similar if their Jaccard similarity is above the threshold.

    Returns:
        extremes (List[dict]):
            A reduced representation of the cluster. The field copies is added to each dict.
            The copies field indicates the number of similar codes in the cluster for an extreme.
    """
    extremes = []
    for element1 in cluster:
        code1 = _shared_dataset[element1["base_index"]]["content"]
        for element2 in extremes:
            code2 = _shared_dataset[element2["base_index"]]["content"]
            if jaccard_similarity(code1, code2) >= jaccard_threshold:
                element2["copies"] += 1
                break
        else:
            element1["copies"] = 1
            extremes.append(element1)
    return extremes


def find_extremes(cluster_list, dataset, jaccard_threshold):
    """Call the _find_cluster_extremes_shared function in a parallel fashion.

    Args:
        cluster_list (List[List[Dict]]):
            each cluster is a list of dicts with the key base_index,
            referring to the index of the base code in the dataset.
        dataset (Type[Dataset]):
            dataset is used to access the content of the code snippets,
            using the base_index from the cluster_list.
            dataset is shared between all the processes using a global variable (any other way to share the dataset?),
            otherwise the multiprocessing is not sped up.
        jaccard_threshold (float):
            the threshold for the Jaccard similarity. The default value is 0.85.

    Returns:
        extremes_list (List[Dict]):
            Each cluster is reduced to its extremes.
            See _find_cluster_extremes_shared for the definition of extremes.
    """
    global _shared_dataset
    _shared_dataset = dataset
    extremes_list = []
    f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in tqdm(
            pool.imap_unordered(
                f,
                cluster_list,
            ),
            total=len(cluster_list),
        ):
            extremes_list.append(extremes)
    return extremes_list


def deduplicate_dataset(
    dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
    """Deduplicate the dataset using MinHash and Jaccard similarity.
    This function first generates duplicate clusters, then each cluster
    is reduced to the extremes that are similar to the other elements in the cluster.
    Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 by default).

    Args:
        dataset (Type[Dataset]):
            The dataset to deduplicate.
        jaccard_threshold (float, default=0.85):
            Jaccard threshold to determine if two codes are similar.

    Returns:
        ds_dedup (Type[Dataset]):
            The deduplicated dataset.
        duplicate_clusters (List[List[Dict]]):
            The list of duplicate clusters.
            Each cluster is a list of dicts with the following keys:
            - base_index : int
                The index of the code in the original dataset.
            - repo_name : str
            - path : str
            - copies : int
                The number of copies of the code in the cluster (see _find_cluster_extremes_shared).
            - is_extreme : bool
                Whether the code is an extreme in the cluster.
            All the codes in the cluster are removed from the dataset except the extremes.

    Example:
        >>> from datasets import load_dataset
        >>> from minhash_deduplication import deduplicate_dataset
        >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
        >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    """
    duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
    duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
    extreme_dict = {}
    extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
    for extremes in extremes_clusters:
        for element in extremes:
            extreme_dict[element["base_index"]] = element
    remove_indices = duplicate_indices - set(extreme_dict.keys())
    ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

    # update duplicate_clusters
    for cluster in duplicate_clusters:
        for element in cluster:
            element["is_extreme"] = element["base_index"] in extreme_dict
            if element["is_extreme"]:
                element["copies"] = extreme_dict[element["base_index"]]["copies"]

    print(f"Original dataset size: {len(dataset)}")
    print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
    print(f"Files in duplicate cluster: {len(duplicate_indices)}")
    print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
    print(f"Filtered dataset size: {len(ds_filter)}")

    return ds_filter, duplicate_clusters
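The greedy reduction implemented by `_find_cluster_extremes_shared` above is easier to follow on plain strings. The sketch below applies the same for/else logic with exact Jaccard similarity; the names are invented and it is an illustration, not the commit's code.

```python
import re

NON_ALPHA = re.compile("[^A-Za-z_0-9]")


def jaccard(code1: str, code2: str) -> float:
    """Exact Jaccard similarity over alphanumeric token sets."""
    tokens1 = {t for t in NON_ALPHA.split(code1) if t.strip()}
    tokens2 = {t for t in NON_ALPHA.split(code2) if t.strip()}
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)


def reduce_to_extremes(snippets, threshold=0.85):
    """Greedily keep a snippet only if no already-kept snippet is similar to it."""
    extremes = []
    for code in snippets:
        for kept in extremes:
            if jaccard(code, kept["code"]) >= threshold:
                kept["copies"] += 1  # covered by an existing extreme
                break
        else:  # no kept snippet was similar: this one becomes a new extreme
            extremes.append({"code": code, "copies": 1})
    return extremes


print(reduce_to_extremes(["a b c d", "a  b  c  d", "x y z"]))
# [{'code': 'a b c d', 'copies': 2}, {'code': 'x y z', 'copies': 1}]
```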
@@ -1,14 +1,17 @@
import gzip
import hashlib
import json
import multiprocessing
import os
import shutil
import time
from pathlib import Path

import numpy as np
from datasets import load_dataset

from arguments import PreprocessingArguments
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser
@@ -146,7 +149,7 @@ def filter(example, uniques, args):
def compress_file(file_path):
    """Compress a file with g-zip."""
    with open(file_path, "rb") as f_in:
-        with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
+        with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.unlink(file_path)
@@ -179,12 +182,29 @@ ds_filter = ds.filter(filter, fn_kwargs={"uniques": uniques, "args": args})
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")

# Deduplicate with minhash and jaccard similarity
if args.near_deduplication:
    t_start = time.time()
    ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
    print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
    print(f"Size of deduplicated dataset: {len(ds_filter)}")

# Save data in batches of samples_per_file
-if not os.path.exists(args.output_dir):
-    os.makedirs(args.output_dir)
+output_dir = Path(args.output_dir)
+output_dir.mkdir(exist_ok=True)

# save duplicate_clusters in the output_dir as artifacts
# not sure it is the right place to save it
if args.near_deduplication:
    with open(output_dir / "duplicate_clusters.json", "w") as f:
        json.dump(duplicate_clusters, f)

data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)

t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
-    file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
+    file_path = str(data_dir / f"file-{file_number+1:012}.json")
    end_index = min(len(ds_filter), index + args.samples_per_file)
    ds_filter.select(list(range(index, end_index))).to_json(file_path)
    compress_file(file_path)
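For reference, each cluster written to `duplicate_clusters.json` by the block above ends up shaped roughly like the following. The values are invented for illustration; the keys follow the `deduplicate_dataset` docstring, and `copies` is only present on the extreme kept in the filtered dataset.

```python
# One cluster from duplicate_clusters.json (illustrative values only).
example_cluster = [
    {"base_index": 40, "repo_name": "org/repo_a", "path": "utils/io.py", "copies": 3, "is_extreme": True},
    {"base_index": 187, "repo_name": "org/repo_b", "path": "io.py", "is_extreme": False},
    {"base_index": 951, "repo_name": "org/repo_c", "path": "lib/io.py", "is_extreme": False},
]
```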
@@ -0,0 +1,30 @@
from unittest import TestCase

from datasets import Dataset

from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


def get_dataset():
    data_dict = {
        "repo_name": ["test_repo1", "test_repo2", "test_repo3"],
        "path": ["test_1.py", "test_2.py", "unit_test.py"],
        "content": ["a " * 20, "a " * 30, "b " * 7],
    }
    dataset = Dataset.from_dict(data_dict)
    return dataset


class MakeDuplicateClustersTest(TestCase):
    def test_make_duplicate_clusters(self):
        ds = get_dataset()
        duplicate_clusters = make_duplicate_clusters(ds, 0.85)
        self.assertEqual(len(duplicate_clusters[0]), 2)

    def test_deduplicate_dataset(self):
        ds = get_dataset()
        ds_filter, duplicate_clusters = deduplicate_dataset(ds)
        self.assertEqual(len(ds_filter), 2)
        print(duplicate_clusters)
        self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
        self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)
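A quick way to check why these assertions hold (my reading of the fixtures, not part of the commit): both "a "-based files tokenize to the single token `a`, so their Jaccard similarity is 1.0 and they form one cluster with a single extreme (`copies=2`), while `"b " * 7` yields only 7 tokens, below `MIN_NUM_TOKENS = 10`, so it is never min-hashed.

```python
import re

NON_ALPHA = re.compile("[^A-Za-z_0-9]")

for content in ["a " * 20, "a " * 30, "b " * 7]:
    tokens = [t for t in NON_ALPHA.split(content) if t.strip()]
    print(len(tokens), set(tokens))
# 20 {'a'}, 30 {'a'}, 7 {'b'} -> only the two "a" files are min-hashed, and they collide
```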