Generate: PT's top_p enforces min_tokens_to_keep when it is 1 (#24111)

This commit is contained in:
Joao Gante 2023-06-09 13:20:05 +01:00 committed by GitHub
parent 03585f3734
commit be10092e63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 3 deletions

View File

@@ -129,6 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value

View File

@@ -255,6 +255,8 @@ class TopPLogitsWarper(LogitsWarper):
top_p = float(top_p)
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value
@@ -266,9 +268,8 @@ class TopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

View File

@@ -160,6 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value