transformers/examples/research_projects/codeparrot/scripts/tests/test_deduplicate.py
Jia LI da2bd2ae96
[CodeParrot] Near-deduplication with jaccard similarity (#17054)
* deduplication draft

* update style

* update style test

* dummy test main

* rename modules

* rename functions

* return extremes in deduplicate_clusters

* update style

* cast str for gzip

* update doc string

* time processing

* use dataset map to compute minhash

* fill value for short token

* remove da map method

* update style

* use share object to multiprocess

* update style

* use f-string and minor fix

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* update style

* use module parameters

* change ds_dedup to ds_filter

* save ds_dedup

* mv test to script tests

* make jaccard threshold a parameter of deduplicate_dataset

* update style

* add doc strings

* update style

* add doc string for DuplicationIndex

* save files into data dir

* update readme

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>

* make near deduplication optional

* move near deduplication in README

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* use f string

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
2022-06-21 14:23:36 +02:00


from unittest import TestCase

from datasets import Dataset
from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


def get_dataset():
    # Three toy files: the first two share the same vocabulary ("a"),
    # the third is short and uses a different token ("b").
    data_dict = {
        "repo_name": ["test_repo1", "test_repo2", "test_repo3"],
        "path": ["test_1.py", "test_2.py", "unit_test.py"],
        "content": ["a " * 20, "a " * 30, "b " * 7],
    }
    dataset = Dataset.from_dict(data_dict)
    return dataset


class MakeDuplicateClustersTest(TestCase):
    def test_make_duplicate_clusters(self):
        ds = get_dataset()
        # 0.85 is the Jaccard similarity threshold.
        duplicate_clusters = make_duplicate_clusters(ds, 0.85)
        # The two "a" files should land in a single cluster of size 2.
        self.assertEqual(len(duplicate_clusters[0]), 2)

    def test_deduplicate_dataset(self):
        ds = get_dataset()
        ds_filter, duplicate_clusters = deduplicate_dataset(ds)
        # One of the two near-duplicates is dropped, leaving 2 of the 3 files.
        self.assertEqual(len(ds_filter), 2)
        print(duplicate_clusters)
        # The first element of the cluster is marked as the extreme and records 2 copies.
        self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
        self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)
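
The expectations in these tests follow from the Jaccard similarity of the toy files. The sketch below is only an illustration of that measure on the same three `content` strings: it computes the exact Jaccard similarity of token sets and is not the MinHash-based code in `minhash_deduplication`. The helper names `token_set` and `jaccard_similarity` and the tokenization regex are assumptions made for this example.

```python
import re

# Hypothetical helpers for illustration only; they are not imported from
# minhash_deduplication and the tokenization rule is an assumption.
NON_ALPHANUMERIC = re.compile(r"[^A-Za-z_0-9]")


def token_set(code: str) -> set:
    """Split on non-alphanumeric characters and keep the distinct non-empty tokens."""
    return {token for token in NON_ALPHANUMERIC.split(code) if token}


def jaccard_similarity(code1: str, code2: str) -> float:
    """Exact Jaccard similarity: size of the intersection over size of the union."""
    tokens1, tokens2 = token_set(code1), token_set(code2)
    union = tokens1 | tokens2
    if not union:
        # Two empty token sets: treat as dissimilar to avoid dividing by zero.
        return 0.0
    return len(tokens1 & tokens2) / len(union)


contents = ["a " * 20, "a " * 30, "b " * 7]
print(jaccard_similarity(contents[0], contents[1]))  # 1.0
print(jaccard_similarity(contents[0], contents[2]))  # 0.0
```

Under these assumptions the first two files have similarity 1.0, which is above the 0.85 threshold used in `test_make_duplicate_clusters`, while the third file shares no tokens with them. That is consistent with a single cluster of two files and a filtered dataset of two rows.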