transformers/tests/generation/test_candidate_generator.py
Nadav Timor e3ee49fcfb
Refactoring AssistedCandidateGenerator for Improved Modularity and Reusability (#35009)
* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file

* refactor

* No-op: add a space to rerun the GitHub Actions tests

* remove it...

* No-op: add a space to rerun the GitHub Actions tests

* remove it...

* replace: `self.prev_tokens` -> `self.prev_assistant_ids`

* No-op: rerun the CI tests

* remove it

* introduce `self.prev_target_ids_len`

* fix style

* fix style

---------

Co-authored-by: Jonathan Mamou <jonathan.mamou@intel.com>
2024-12-12 15:47:05 +01:00

44 lines
1.9 KiB
Python

import unittest
import numpy as np
from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))
def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))
def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))