diff --git a/examples/run_seq2seq_finetuning_test.py b/examples/run_seq2seq_finetuning_test.py index e59f016da4d..77dc58666cb 100644 --- a/examples/run_seq2seq_finetuning_test.py +++ b/examples/run_seq2seq_finetuning_test.py @@ -21,43 +21,21 @@ class DataLoaderTest(unittest.TestCase): def setUp(self): self.block_size = 10 - def test_truncate_source_and_target_too_small(self): - """ When the sum of the lengths of the source and target sequences is - smaller than the block size (minus the number of special tokens), skip the example. """ - src_seq = [1, 2, 3, 4] - tgt_seq = [5, 6] - self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None) + def test_truncate_sequence_too_small(self): + """ Pad the sequence with 0 if the sequence is smaller than the block size.""" + sequence = [1, 2, 3, 4] + expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0] + self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) - def test_truncate_source_and_target_fit_exactly(self): - """ When the sum of the lengths of the source and target sequences is - equal to the block size (minus the number of special tokens), return the - sequences unchanged. """ - src_seq = [1, 2, 3, 4] - tgt_seq = [5, 6, 7] - fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(src_seq, fitted_src) - self.assertListEqual(tgt_seq, fitted_tgt) + def test_truncate_sequence_fit_exactly(self): + sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) - def test_truncate_source_too_big_target_ok(self): - src_seq = [1, 2, 3, 4, 5, 6] - tgt_seq = [1, 2] - fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) - self.assertListEqual(fitted_tgt, fitted_tgt) - - def test_truncate_target_too_big_source_ok(self): - src_seq = [1, 2, 3, 4] - tgt_seq = [1, 2, 3, 4] - fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(fitted_src, src_seq) - self.assertListEqual(fitted_tgt, [1, 2, 3]) - - def test_truncate_source_and_target_too_big(self): - src_seq = [1, 2, 3, 4, 5, 6, 7] - tgt_seq = [1, 2, 3, 4, 5, 6, 7] - fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) - self.assertListEqual(fitted_tgt, [1, 2]) + def test_truncate_sequence_too_big(self): + sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) def test_process_story_no_highlights(self): """ Processing a story with no highlights should raise an exception.