[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201)

fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch

Co-authored-by: chenhanhui <chenhanhui@kanzhun.com>
This commit is contained in:
HanHui 2024-01-10 18:46:49 +08:00 committed by GitHub
parent 932ad8af7a
commit 4df1d69634
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 0 deletions

View File

@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
scores = torch.where(do_early_stop, early_stop_scores, scores)
return scores

View File

@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase):
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
def test_early_stop_processor_multi_eos(self):
input_ids = None
eos_token_id = [2, 3]
min_eos_p = 0.1 ## some small float
scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)