mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Generate: PT's top_p
enforces min_tokens_to_keep
when it is 1
(#24111)
This commit is contained in:
parent
03585f3734
commit
be10092e63
@ -129,6 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
||||
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
||||
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
|
||||
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
|
||||
|
||||
self.top_p = top_p
|
||||
self.filter_value = filter_value
|
||||
|
@ -255,6 +255,8 @@ class TopPLogitsWarper(LogitsWarper):
|
||||
top_p = float(top_p)
|
||||
if top_p < 0 or top_p > 1.0:
|
||||
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
|
||||
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
|
||||
|
||||
self.top_p = top_p
|
||||
self.filter_value = filter_value
|
||||
@ -266,9 +268,8 @@ class TopPLogitsWarper(LogitsWarper):
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
|
@ -160,6 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
|
||||
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
||||
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
|
||||
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
|
||||
|
||||
self.top_p = top_p
|
||||
self.filter_value = filter_value
|
||||
|
Loading…
Reference in New Issue
Block a user