[🛠️] Fix-whisper-breaking-changes (#21965)

* temp fix

* temporary fix

* update

* fix tests

* fixup

* update based on reveiew

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* update to fix tests

* update docstring

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Arthur 2023-03-14 09:23:48 +01:00 committed by GitHub
parent 101a6cd276
commit 2beabd24f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch Whisper model."""
import math
import random
from typing import Optional, Tuple, Union
@ -37,6 +36,7 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
logger = logging.get_logger(__name__)
@ -1510,8 +1510,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`bool`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
language tokens in the `model.generation_config.lang_to_id` dictionary.
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
kwargs:
@ -1543,39 +1543,63 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
generation_config = self.generation_config
if return_timestamps is not None:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set."
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
if task is not None:
generation_config.task = task
if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual
else:
generation_config.return_timestamps = False
if language is not None:
generation_config.language = language
if task is not None:
generation_config.task = task
forced_decoder_ids = []
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
if task is not None or language is not None:
if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
else:
raise ValueError(
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None))
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"):
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
else:
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids