[whisper] Clarify error message when setting max_new_tokens (#33324)

* clarify error message when setting max_new_tokens

* sync error message in test_generate_with_prompt_ids_max_length

* there is no self
This commit is contained in:
benniekiss 2024-09-12 12:48:36 -04:00 committed by GitHub
parent 2f611d30d9
commit 5c6257d1fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View File

@ -1698,8 +1698,8 @@ class WhisperGenerationMixin:
max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
raise ValueError(
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, "
f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of "
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "

View File

@ -1349,8 +1349,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with self.assertRaisesRegex(
ValueError,
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, "
f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of "
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
f"`max_target_positions` of the Whisper model: {config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "