Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix top_k_top_p_filtering having unexpected behavior (#17744)

- Fix `top_k_top_p_filtering` not passing `filter_value` to `TopPLogitsWarper`, which caused any top-p-filtered logits to be set to -inf instead of the specified value
- Add a corresponding test
This commit is contained in:
parent 3ccff0d400
commit 3b00b623b7
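Before this fix, a custom `filter_value` only took effect in the top-k step; the top-p step always wrote the default -inf. A minimal repro sketch of the pre-fix symptom, assuming the era-appropriate top-level import `from transformers import top_k_top_p_filtering` (the logits and arguments mirror the test added below):

import torch
from transformers import top_k_top_p_filtering  # import path assumed for this era of the library

logits = torch.tensor([[1.0, 1.0, 1.0, 0.99, 0.98]])

# Ask for 0.0 as the fill value for filtered positions.
out = top_k_top_p_filtering(logits.clone(), top_k=4, top_p=0.5, filter_value=0.0)
print(out)
# Before this commit: the position filtered by top-k (0.98) becomes 0.0 as requested,
# but the position filtered by top-p (0.99) becomes -inf, because filter_value was
# never forwarded to TopPLogitsWarper. After the commit, both become 0.0.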
@@ -3347,6 +3347,8 @@ def top_k_top_p_filtering(
         )
 
     if 0 <= top_p <= 1.0:
-        logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
+        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
+            None, logits
+        )
 
     return logits
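For context, `TopPLogitsWarper` masks every token outside the nucleus with its `filter_value`, which defaults to -inf; if the caller's value is not forwarded, the default wins. A minimal sketch of that top-p masking step (not the library's exact code; the function name and variables are illustrative):

import torch

def top_p_mask(logits: torch.Tensor, top_p: float, filter_value: float = -float("inf"),
               min_tokens_to_keep: int = 1) -> torch.Tensor:
    """Sketch of nucleus (top-p) filtering: keep the smallest set of tokens whose
    cumulative probability exceeds top_p; fill all other logits with filter_value."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Mark tokens that come after the cumulative probability has passed top_p.
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is itself kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Never remove the first min_tokens_to_keep tokens.
    sorted_indices_to_remove[..., :min_tokens_to_keep] = False

    # Scatter the removal mask back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, filter_value)

With the fix, `top_k_top_p_filtering` forwards `filter_value` into this step, so positions masked by top-p receive the caller's value (0.0 in the test below) rather than the hard-coded default -inf.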
@@ -1626,6 +1626,32 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
         self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
 
+    # tests whether the function uses filter_value instead of default -inf
+    def test_top_k_top_p_filtering_with_filter_value(self):
+        logits = torch.tensor(
+            [
+                [
+                    1,
+                    1,
+                    1,
+                    0.99,  # get filtered by top-p filtering
+                    0.98,  # get filtered by top-k filtering
+                ]
+            ],
+            dtype=torch.float,
+            device=torch_device,
+        )
+
+        expected_output = torch.tensor(
+            [[1, 1, 1, 0, 0]],
+            dtype=torch.float,
+            device=torch_device,
+        )
+
+        output = top_k_top_p_filtering(logits, top_k=4, top_p=0.5, filter_value=0.0)
+
+        self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
+
 
 @require_torch
 class GenerationIntegrationTests(unittest.TestCase):
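To see why the expected output is [1, 1, 1, 0, 0]: top-k=4 runs first and masks the smallest logit (0.98) to filter_value=0.0, then top-p=0.5 keeps only the smallest set of tokens whose cumulative probability exceeds 0.5. After softmax, each of the three logits equal to 1 carries roughly 0.23 probability, so the nucleus closes at the third of them (cumulative ≈ 0.69) and 0.99 is masked as well. A quick standalone check of that arithmetic (values rounded):

import torch

logits = torch.tensor([1.0, 1.0, 1.0, 0.99, 0.0])  # after top-k=4 has masked 0.98 to 0.0
probs = logits.softmax(dim=-1)
print(probs.sort(descending=True).values.cumsum(dim=-1))
# ~tensor([0.2295, 0.4590, 0.6885, 0.9156, 1.0000]):
# the cumulative sum first exceeds top_p=0.5 at the third token,
# so only the three logits equal to 1 survive top-p filtering.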