mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201)
fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch Co-authored-by: chenhanhui <chenhanhui@kanzhun.com>
This commit is contained in:
parent
932ad8af7a
commit
4df1d69634
@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
||||
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
|
||||
|
||||
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
|
||||
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
|
||||
scores = torch.where(do_early_stop, early_stop_scores, scores)
|
||||
|
||||
return scores
|
||||
|
@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
def test_early_stop_processor_multi_eos(self):
|
||||
input_ids = None
|
||||
eos_token_id = [2, 3]
|
||||
min_eos_p = 0.1 ## some small float
|
||||
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
Loading…
Reference in New Issue
Block a user