import unittest

import numpy as np

from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
    """Tests for ``AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag``.

    Every case feeds a ``prompt`` array and a ``prompt_plus_new_tokens`` array (both
    2D with a single row here) and inspects the returned triple
    ``(discrep_length, new_tokens_only, discrep_only)``.
    """

    @staticmethod
    def _diag(prompt, prompt_plus_new_tokens):
        """Forward to the static method under test, resolved at call time."""
        return AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)

    def _check_zero_discrepancy(self, prompt, prompt_plus_new_tokens, expected_new_tokens):
        """Assert a zero discrepancy length, the expected new tokens, and an empty discrepancy slice.

        Args:
            prompt: 2D token array passed as the first argument of ``_get_tokens_diag``.
            prompt_plus_new_tokens: 2D token array passed as the second argument.
            expected_new_tokens: array that ``new_tokens_only`` must equal element-wise.
        """
        length, fresh_tokens, discrepancy = self._diag(prompt, prompt_plus_new_tokens)
        self.assertEqual(length, 0)
        np.testing.assert_array_equal(fresh_tokens, expected_new_tokens)
        np.testing.assert_array_equal(discrepancy, np.array([[]]))

    def test_no_intersection(self):
        # No token is shared between the two arrays, so every output is None.
        outcome = self._diag(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]))
        self.assertEqual(outcome, (None, None, None))

    def test_complete_overlap(self):
        # The whole prompt is a prefix of the extended sequence; [4, 5] are new.
        self._check_zero_discrepancy(
            np.array([[1, 2, 3]]),
            np.array([[1, 2, 3, 4, 5]]),
            np.array([[4, 5]]),
        )

    def test_partial_overlap(self):
        # Only the prompt's tail [2, 3] opens the extended sequence; [4, 5] are new.
        self._check_zero_discrepancy(
            np.array([[1, 2, 3]]),
            np.array([[2, 3, 4, 5]]),
            np.array([[4, 5]]),
        )

    def test_no_new_tokens(self):
        # Identical inputs: nothing new and no discrepancy.
        self._check_zero_discrepancy(
            np.array([[1, 2, 3]]),
            np.array([[1, 2, 3]]),
            np.array([[]]),
        )