mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[🛠️] Fix-whisper-breaking-changes (#21965)
* temp fix * temporary fix * update * fix tests * fixup * update based on review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * update to fix tests * update docstring --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent
101a6cd276
commit
2beabd24f0
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Whisper model."""
|
||||
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple, Union
|
||||
@ -37,6 +36,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_whisper import WhisperConfig
|
||||
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -1510,8 +1510,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
|
||||
will be updated accordingly.
|
||||
language (`bool`, *optional*):
|
||||
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
|
||||
language tokens in the `model.generation_config.lang_to_id` dictionary.
|
||||
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
|
||||
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
|
||||
is_multilingual (`bool`, *optional*):
|
||||
Whether or not the model is multilingual.
|
||||
kwargs:
|
||||
@ -1543,39 +1543,63 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
generation_config = self.generation_config
|
||||
|
||||
if return_timestamps is not None:
|
||||
if not hasattr(generation_config, "no_timestamps_token_id"):
|
||||
raise ValueError(
|
||||
"You are trying to return timestamps, but the generation config is not properly set."
|
||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
|
||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
||||
)
|
||||
|
||||
generation_config.return_timestamps = return_timestamps
|
||||
|
||||
if task is not None:
|
||||
generation_config.task = task
|
||||
|
||||
if is_multilingual is not None:
|
||||
generation_config.is_multilingual = is_multilingual
|
||||
else:
|
||||
generation_config.return_timestamps = False
|
||||
|
||||
if language is not None:
|
||||
generation_config.language = language
|
||||
if task is not None:
|
||||
generation_config.task = task
|
||||
|
||||
forced_decoder_ids = []
|
||||
|
||||
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
|
||||
if task is not None or language is not None:
|
||||
if hasattr(generation_config, "language"):
|
||||
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
|
||||
if generation_config.language in generation_config.lang_to_id.keys():
|
||||
language_token = generation_config.language
|
||||
elif generation_config.language in TO_LANGUAGE_CODE.keys():
|
||||
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported language: {self.language}. Language should be one of:"
|
||||
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
|
||||
)
|
||||
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
||||
else:
|
||||
forced_decoder_ids.append((1, None))
|
||||
forced_decoder_ids.append((1, None)) # automatically detect the language
|
||||
|
||||
if hasattr(generation_config, "task"):
|
||||
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
|
||||
if generation_config.task in TASK_IDS:
|
||||
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
|
||||
)
|
||||
else:
|
||||
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
|
||||
|
||||
if (
|
||||
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
|
||||
) or return_timestamps:
|
||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||
else:
|
||||
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
|
||||
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
|
||||
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
|
||||
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
||||
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
||||
|
||||
# Legacy code for backward compatibility
|
||||
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
|
||||
forced_decoder_ids = self.config.forced_decoder_ids
|
||||
elif (
|
||||
hasattr(self.generation_config, "forced_decoder_ids")
|
||||
and self.generation_config.forced_decoder_ids is not None
|
||||
):
|
||||
forced_decoder_ids = self.generation_config.forced_decoder_ids
|
||||
|
||||
if generation_config.return_timestamps:
|
||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||
|
||||
if len(forced_decoder_ids) > 0:
|
||||
generation_config.forced_decoder_ids = forced_decoder_ids
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user