mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix test of truncation function
This commit is contained in:
parent
a67413ccc8
commit
932543f77e
@ -21,43 +21,21 @@ class DataLoaderTest(unittest.TestCase):
|
||||
def setUp(self):
    """ Fix the block size that the truncation tests pass to `_fit_to_block_size`. """
    self.block_size = 10
|
||||
|
||||
def test_truncate_source_and_target_too_small(self):
    """ The example is skipped (`None` is returned) when source and target
    together are shorter than the block size minus the number of special tokens. """
    source, target = [1, 2, 3, 4], [5, 6]
    outcome = _fit_to_block_size(source, target, self.block_size)
    self.assertEqual(outcome, None)
|
||||
def test_truncate_sequence_too_small(self):
    """ A sequence shorter than the block size gets 0s appended until it fills the block. """
    short_sequence = [1, 2, 3, 4]
    padded = [1, 2, 3, 4] + [0] * 6
    fitted = _fit_to_block_size(short_sequence, self.block_size)
    self.assertEqual(fitted, padded)
|
||||
|
||||
def test_truncate_source_and_target_fit_exactly(self):
    """ Source and target come back untouched when their combined length
    equals the block size minus the number of special tokens. """
    source = [1, 2, 3, 4]
    target = [5, 6, 7]
    result = _fit_to_block_size(source, target, self.block_size)
    new_source, new_target = result
    self.assertListEqual(source, new_source)
    self.assertListEqual(target, new_target)
|
||||
def test_truncate_sequence_fit_exactly(self):
    """ A sequence as long as the block size (10, see setUp) is returned as is. """
    full_block = list(range(1, 11))
    fitted = _fit_to_block_size(full_block, self.block_size)
    self.assertEqual(fitted, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
def test_truncate_source_too_big_target_ok(self):
    """ When the pair overflows the block and only the source is too long,
    the source is shortened and the target must come back unchanged. """
    src_seq = [1, 2, 3, 4, 5, 6]
    tgt_seq = [1, 2]
    fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
    self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
    # Compare against the input target: asserting fitted_tgt against itself
    # always passes and would not catch a wrongly truncated target.
    self.assertListEqual(fitted_tgt, tgt_seq)
|
||||
|
||||
def test_truncate_target_too_big_source_ok(self):
    """ When only the target is too long, the target is shortened and the
    source is returned unchanged. """
    source = [1, 2, 3, 4]
    target = [1, 2, 3, 4]
    new_source, new_target = _fit_to_block_size(source, target, self.block_size)
    self.assertListEqual(new_source, source)
    self.assertListEqual(new_target, [1, 2, 3])
|
||||
|
||||
def test_truncate_source_and_target_too_big(self):
    """ When both sequences are too long, both are shortened: the source to
    its first five tokens and the target to its first two. """
    source = list(range(1, 8))
    target = list(range(1, 8))
    new_source, new_target = _fit_to_block_size(source, target, self.block_size)
    self.assertListEqual(new_source, [1, 2, 3, 4, 5])
    self.assertListEqual(new_target, [1, 2])
|
||||
def test_truncate_sequence_too_big(self):
    """ A sequence longer than the block size is cut down to its first
    ten tokens (the block size set in setUp). """
    long_sequence = list(range(1, 14))
    fitted = _fit_to_block_size(long_sequence, self.block_size)
    self.assertEqual(fitted, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights should raise an exception.
|
||||
|
Loading…
Reference in New Issue
Block a user