Generate: nudge towards do_sample=False when temperature=0.0 (#25722)

This commit is contained in:
Joao Gante 2023-08-24 14:15:43 +01:00 committed by GitHub
parent 584eeb5387
commit 0a365c3e6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -266,7 +266,13 @@ class TemperatureLogitsWarper(LogitsWarper):
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, float) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)
self.temperature = temperature