Added handling for suppress_tokens of length < 2 for Whisper (#36336)

* Update generation_whisper.py

Added handling for suppress_tokens with fewer than 2 elements for Whisper

* Updated None check for suppress_tokens to avoid ambiguity

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
andreystarenky 2025-02-25 06:32:49 -04:00 committed by GitHub
parent da4ab2a1b6
commit 3a02fe56c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1779,7 +1779,10 @@ class WhisperGenerationMixin(GenerationMixin):
prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
if prev_start_of_text is None:
prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
if suppress_tokens is not None and len(suppress_tokens) >= 2:
prev_start_of_text = suppress_tokens[-2]
else:
prev_start_of_text = None
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609