[Whisper] Refactor forced_decoder_ids & prompt ids (#28687)
* up
* Fix more
* Correct more
* Fix more tests
* fix fast tests
* Fix more
* fix more
* push all files
* finish all
* make style
* Fix timestamp wrap
* make style
* make style
* up
* up
* up
* Fix lang detection behavior
* Fix lang detection behavior
* Add lang detection test
* Fix lang detection behavior
* make style
* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* better error message
* make style tests
* add warning

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent f9f1f2ac5e
commit 65a926e82b
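A minimal usage sketch of the refactored `generate` API that this commit introduces (the checkpoint and the dummy audio are placeholder assumptions, not part of the commit):

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# Language and task are now passed directly instead of via `forced_decoder_ids`.
generated_ids = model.generate(inputs.input_features, language="en", task="transcribe")
print(processor.batch_decode(generated_ids, skip_special_tokens=True))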
@@ -16,7 +16,7 @@ import copy
import math
import warnings
import zlib
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -174,6 +174,8 @@ class WhisperGenerationMixin:
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])

weight_length = None

if "beam_indices" in generate_outputs:
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
# since the beam search strategy chooses the most probable sequences at the end of the search.
@@ -195,7 +197,9 @@ class WhisperGenerationMixin:
dim=2,
)

timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
# make sure timestamps are as long as weights
input_length = weight_length or cross_attentions[0].shape[2]
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
batch_size = timestamps.shape[0]

if num_frames is not None:
@@ -260,6 +264,7 @@ class WhisperGenerationMixin:
language: Optional[str] = None,
is_multilingual: Optional[bool] = None,
prompt_ids: Optional[torch.Tensor] = None,
prompt_condition_type: Optional[str] = None,  # first-segment, all-segments
condition_on_prev_tokens: Optional[bool] = None,
temperature: Optional[Union[float, Tuple[float, ...]]] = None,
compression_ratio_threshold: Optional[float] = None,
@@ -333,6 +338,9 @@ class WhisperGenerationMixin:
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
prompt_condition_type (`str`, *optional*):
Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
Defaults to 'first-segment'. For short-form transcription only 'first-segment' is possible.
condition_on_prev_tokens (`bool`, *optional*):
Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
@@ -474,7 +482,7 @@ class WhisperGenerationMixin:
# 2. set global generate variables
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
num_segment_frames = input_stride * self.config.max_source_positions
total_input_frames = self._retrieve_total_input_frames(
batch_size, total_input_frames = self._retrieve_total_input_frames(
input_features=input_features, input_stride=input_stride, kwargs=kwargs
)
is_shortform = total_input_frames <= num_segment_frames
@@ -505,15 +513,6 @@ class WhisperGenerationMixin:
self._set_language_and_task(
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
)
# pass self.config for backward compatibility
self._set_forced_decoder_ids(
task=task,
language=language,
prompt_ids=prompt_ids,
generation_config=generation_config,
config=self.config,
kwargs=kwargs,
)
self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
self._set_num_frames(
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
@@ -525,12 +524,31 @@ class WhisperGenerationMixin:
no_speech_threshold=no_speech_threshold,
condition_on_prev_tokens=condition_on_prev_tokens,
)
self._set_prompt_condition_type(
generation_config=generation_config,
prompt_condition_type=prompt_condition_type,
)

# 4. Retrieve logits processors
# pass self.config for backward compatibility
init_tokens = self._retrieve_init_tokens(
input_features,
generation_config=generation_config,
config=self.config,
num_segment_frames=num_segment_frames,
kwargs=kwargs,
)
# TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead
# This function should be removed in v4.39
self._check_decoder_input_ids(
prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
)

# 3. Retrieve logits processors
begin_index = len(init_tokens)
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
no_speech_threshold=no_speech_threshold,
begin_index=begin_index,  # begin index is index of first generated decoder token
is_shortform=is_shortform,
num_beams=kwargs.get("num_beams", 1),
)
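With the refactor, `begin_index` is simply the number of tokens in the decoder prefix. A toy illustration (the token ids are the usual multilingual Whisper special tokens, quoted from memory):

# <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
init_tokens = [50258, 50259, 50359, 50363]
begin_index = len(init_tokens)  # == 4: index of the first *generated* token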
@@ -540,6 +558,27 @@ class WhisperGenerationMixin:
if temperature is not None:
kwargs["temperature"] = temperature

decoder_input_ids = kwargs.pop("decoder_input_ids", None)
if decoder_input_ids is None:
one_tensor = torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)

if prompt_ids is not None:
decoder_input_ids = torch.cat(
[prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
)

if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
max_new_tokens = kwargs.get("max_new_tokens", 0)
raise ValueError(
f"The length of `decoder_input_ids`, equal to `prompt_ids` plus special start tokens, is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less than {self.config.max_target_positions}."
)

outputs = super().generate(
input_features,
generation_config=generation_config,
@@ -547,6 +586,7 @@ class WhisperGenerationMixin:
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
**kwargs,
)
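A standalone sketch of how the decoder prefix above is materialized (shapes and token ids are made up for illustration):

import torch

batch_size = 2
init_tokens = [50258, 50259, 50359]                       # start/language/task tokens
prompt_ids = torch.tensor([50361, 11], dtype=torch.long)  # assumed <|startofprev|> + one prompt token

one_tensor = torch.ones((batch_size, 1), dtype=torch.long)
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)  # shape (2, 3)
decoder_input_ids = torch.cat(
    [prompt_ids[None].repeat(batch_size, 1), decoder_input_ids], dim=-1       # shape (2, 5)
)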
@@ -573,11 +613,15 @@ class WhisperGenerationMixin:
max_frames, seek = self._retrieve_max_frames_and_seek(
batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
)
init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config)

# 6.2 Prepare running variables, list for generation
cur_bsz = batch_size
current_segments = [[] for _ in range(batch_size)]
current_segments = self._prepare_segments(
prompt_ids=prompt_ids,
batch_size=batch_size,
generation_config=generation_config,
)

batch_idx_map = list(range(batch_size))
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
@@ -617,6 +661,7 @@ class WhisperGenerationMixin:
current_segments=current_segments,
batch_idx_map=batch_idx_map,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
prompt_ids=prompt_ids,
generation_config=generation_config,
config=self.config,
device=segment_input.device,
@@ -682,11 +727,16 @@ class WhisperGenerationMixin:

# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right")
final_segments = (
[x[1:] for x in current_segments]
if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
else current_segments
)
sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}
return {"sequences": sequences, "segments": final_segments}

return sequences
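The long-form return value with `return_segments=True` then has roughly this shape (a hedged sketch; `long_input_features` stands for more than 30 s of audio features, and any field beyond `sequences`/`segments`/`tokens` is an assumption):

outputs = model.generate(long_input_features, return_timestamps=True, return_segments=True)
padded_ids = outputs["sequences"]          # (batch, max_len), right-padded, prompt segment stripped
first_segment = outputs["segments"][0][0]  # per-segment dict containing at least "tokens"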
@@ -721,7 +771,8 @@ class WhisperGenerationMixin:

for fallback_idx, temperature in enumerate(temperatures):
generation_config.do_sample = temperature is not None and temperature > 0.0
generation_config.temperature = temperature

generation_config.temperature = temperature if generation_config.do_sample else 1.0
generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
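# Hedged sketch of the fallback schedule this loop implements, following the
# Whisper paper's (0.0, 0.2, ..., 1.0) schedule; the helper names
# `decode_segment` / `needs_fallback` are hypothetical, not library functions:
#
#     for temperature in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):
#         do_sample = temperature > 0.0      # greedy/beam at 0.0, sampling otherwise
#         result = decode_segment(do_sample, temperature if do_sample else 1.0)
#         if not needs_fallback(result):     # compression-ratio / logprob thresholds
#             break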

seek_outputs = super().generate(
@@ -736,13 +787,13 @@ class WhisperGenerationMixin:
)

# post-process sequence tokens and outputs to be in list form
sequence_tokens, seek_outputs = self._postprocess_outputs(
seek_outputs, return_token_timestamps, generation_config
seek_sequences, seek_outputs = self._postprocess_outputs(
seek_outputs=seek_outputs,
decoder_input_ids=decoder_input_ids,
return_token_timestamps=return_token_timestamps,
generation_config=generation_config,
)

# remove all previously passed decoder input ids
seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :]

# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
@@ -777,8 +828,9 @@ class WhisperGenerationMixin:

seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
do_condition_on_prev_tokens[fallback_index_map[i]] = (
generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5
generation_config.condition_on_prev_tokens and is_low_temperature
)

if needs_fallback[i]:
@@ -804,30 +856,44 @@ class WhisperGenerationMixin:

return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens

def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config):
@staticmethod
def _prepare_segments(prompt_ids, batch_size, generation_config):
if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
else:
current_segments = [[] for _ in range(batch_size)]

return current_segments

def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
# remove all previously passed decoder input ids
if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :]
return seek_outputs, seek_outputs

if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)

if generation_config.return_dict_in_generate:
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]

def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
return values[batch_idx].cpu()
def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
return values[batch_idx].cpu()

sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]
else:
sequence_tokens = seek_outputs
sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]

return sequence_tokens, seek_outputs
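A self-contained sketch of the batch split performed above (dummy tensors; `past_key_values` is dropped exactly because keeping it per sample would be too costly):

import torch

outputs = {"sequences": torch.arange(6).view(2, 3), "past_key_values": object()}

def split_by_batch_index(values, key, batch_idx):
    if key == "past_key_values":
        return None  # intentionally dropped
    return values[batch_idx]

per_sample = [
    {k: split_by_batch_index(v, k, i) for k, v in outputs.items()}
    for i in range(outputs["sequences"].shape[0])
]
assert per_sample[0]["sequences"].tolist() == [0, 1, 2]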
@@ -884,7 +950,7 @@ class WhisperGenerationMixin:
@staticmethod
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
if input_features is not None:
return input_features.shape[-1]
return input_features.shape[0], input_features.shape[-1]

if "encoder_outputs" in kwargs:
encoder_outputs_shape = (
@@ -892,7 +958,7 @@ class WhisperGenerationMixin:
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
return encoder_outputs_shape[1] * input_stride
return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride

raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
@@ -950,34 +1016,24 @@ class WhisperGenerationMixin:

@staticmethod
def _set_return_timestamps(return_timestamps, is_shortform, generation_config):
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = True
elif not is_shortform:
if not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)

if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)

logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
return_timestamps = True

if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)

generation_config.return_timestamps = return_timestamps

@staticmethod
def _set_language_and_task(language, task, is_multilingual, generation_config):
@@ -1016,94 +1072,221 @@ class WhisperGenerationMixin:
)
generation_config.task = task

@staticmethod
def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs):
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
def _retrieve_init_tokens(self, input_features, generation_config, config, num_segment_frames, kwargs):
def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
"""short function that replaces any element of `itr` found in `lst` with `num`, or appends `num`"""
found = any(i in lst for i in itr)
if found:
lst = [num if i in itr else i for i in lst]
else:
lst.append(num)
return lst
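# Toy trace of the helper above, assuming it is lifted out of the method
# (the ids are made up):
#     replace_or_add([1, 8], 9, {7, 8, 9}) -> [1, 9]   # 8 is replaced by 9
#     replace_or_add([1],    9, {7, 8, 9}) -> [1, 9]   # nothing to replace, 9 appended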

task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)

if kwargs.get("forced_decoder_ids", None) is not None:
forced_decoder_ids = kwargs["forced_decoder_ids"]
elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
forced_decoder_ids = generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)

if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
elif generation_config.language in TO_LANGUAGE_CODE.values():
language_token = f"<|{generation_config.language}|>"
else:
is_language_code = len(generation_config.language) == 2
raise ValueError(
f"Unsupported language: {generation_config.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None))  # automatically detect the language

if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
raise ValueError(
f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
)
elif hasattr(generation_config, "task_to_id"):
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids

if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. "
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
else:
forced_decoder_ids = None

# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
if kwargs["max_new_tokens"] >= config.max_target_positions:
raise ValueError(
f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
f"`max_target_positions` of the Whisper model: {config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less than {config.max_target_positions}."
)

# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None:
logger.info(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
)
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.info(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None

init_tokens = [generation_config.decoder_start_token_id]
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
init_tokens += [forced_decoder_ids[0][1]]
forced_decoder_ids = forced_decoder_ids[1:]
i += 1

# TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
if len(forced_decoder_ids) > 0:
warnings.warn(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.",
FutureWarning,
)

# TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39
generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None

is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
if language is not None:
if language in generation_config.lang_to_id.keys():
language_token = language
elif language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
elif language in TO_LANGUAGE_CODE.values():
language_token = f"<|{language}|>"
else:
is_language_code = len(language) == 2
raise ValueError(
f"Unsupported language: {language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
)

lang_id = generation_config.lang_to_id[language_token]

# if language is defined it'll overwrite language ids that might have already been defined via the generation_config
replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values())
elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
# language is not defined or intentionally set to `None` to trigger language detection
lang_ids = self.detect_language(
input_features=input_features,
encoder_outputs=kwargs.get("encoder_outputs", None),
generation_config=generation_config,
num_segment_frames=num_segment_frames,
)

if torch.unique(lang_ids).shape[0] > 1:
raise ValueError(
"Multiple languages detected when trying to predict the most likely target language for transcription. It is currently not supported to transcribe to different languages in a single batch. Please make sure to either force a single language by passing `language='...'` or make sure all input audio is of the same language."
)

lang_id = lang_ids[0].item()

# append or replace lang_id to init_tokens
if len(init_tokens) > 1:
init_tokens[1] = lang_id
else:
init_tokens.append(lang_id)

if task is not None:
if task in TASK_IDS:
init_tokens.append(generation_config.task_to_id[generation_config.task])
task_id = generation_config.task_to_id[generation_config.task]

# if task is defined it'll overwrite task ids that might have already been defined via the generation_config
replace_or_add(init_tokens, task_id, generation_config.task_to_id.values())
else:
raise ValueError(f"The `{task}` task is not supported. The task should be one of `{TASK_IDS}`")
elif language is not None and hasattr(generation_config, "task_to_id"):
# if language is defined, but no task id is in `init_tokens`, default to transcribe
if not any(i in init_tokens for i in generation_config.task_to_id.values()):
init_tokens.append(generation_config.task_to_id["transcribe"])

if (
not generation_config.return_timestamps
and hasattr(generation_config, "no_timestamps_token_id")
and init_tokens[-1] != generation_config.no_timestamps_token_id
):
init_tokens.append(generation_config.no_timestamps_token_id)
elif generation_config.return_timestamps and init_tokens[-1] == generation_config.no_timestamps_token_id:
logger.info(
"<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
)
init_tokens = init_tokens[:-1]

# let's make sure we don't pass `None` tokens as prompt tokens
init_tokens = [t for t in init_tokens if t is not None]

return init_tokens

def detect_language(
self,
input_features: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
generation_config: Optional[GenerationConfig] = None,
num_segment_frames: int = 3000,
) -> torch.Tensor:
"""
Detects language from log-mel input features or encoder_outputs

Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
num_segment_frames (`int`, defaults to 3000):
The number of log-mel frames the model expects

Return:
A `torch.LongTensor` representing the detected language ids.
"""
if input_features is None and encoder_outputs is None:
raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
elif input_features is not None and encoder_outputs is not None:
raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
elif input_features is not None:
inputs = {"input_features": input_features[:, :, :num_segment_frames]}
batch_size = input_features.shape[0]
elif encoder_outputs is not None:
inputs = {"encoder_outputs": encoder_outputs}
batch_size = (
encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
)

generation_config = generation_config or self.generation_config
decoder_input_ids = (
torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
* generation_config.decoder_start_token_id
)

with torch.no_grad():
logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]

non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
non_lang_mask[list(generation_config.lang_to_id.values())] = False

logits[:, non_lang_mask] = -np.inf

lang_ids = logits.argmax(-1)

return lang_ids
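# Hedged usage sketch for the new public method above (the reverse lookup of the
# token string is an illustrative assumption about `lang_to_id`):
#     lang_ids = model.detect_language(input_features=inputs.input_features)
#     id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
#     print(id_to_lang[lang_ids[0].item()])  # e.g. "<|en|>"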

@staticmethod
def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None)
if prompt_ids is not None and decoder_input_ids is not None:
raise ValueError(
f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it."
)
elif decoder_input_ids is not None and not is_shortform:
raise ValueError(
f"Cannot pass `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
)
elif decoder_input_ids is not None and is_shortform:
warnings.warn(
f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
FutureWarning,
)
forced_decoder_ids = [
*text_prompt_ids,
generation_config.decoder_start_token_id,
*[token for _, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids

@staticmethod
def _set_token_ids(generation_config, config, kwargs):
@@ -1162,6 +1345,25 @@ class WhisperGenerationMixin:
else getattr(generation_config, "condition_on_prev_tokens", None)
)

@staticmethod
def _set_prompt_condition_type(generation_config, prompt_condition_type):
allowed_cond_types = ["first-segment", "all-segments"]

# default to "first-segment"
prompt_condition_type = prompt_condition_type or allowed_cond_types[0]

if prompt_condition_type not in allowed_cond_types:
raise ValueError(
f"`prompt_condition_type={prompt_condition_type}` does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
)

if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
raise ValueError(
"Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
)

generation_config.prompt_condition_type = prompt_condition_type

@staticmethod
def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
condition_on_prev_tokens = (
@@ -1175,7 +1377,7 @@ class WhisperGenerationMixin:
def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames):
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
"When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
@@ -1186,37 +1388,7 @@ class WhisperGenerationMixin:

return max_frames, seek

@staticmethod
def _retrieve_init_tokens_from_forced_decoder_ids(generation_config):
init_tokens = [generation_config.decoder_start_token_id]
forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
init_tokens += [forced_decoder_ids[0][1]]
forced_decoder_ids = forced_decoder_ids[1:]
i += 1

forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
generation_config.forced_decoder_ids = forced_decoder_ids

return init_tokens

def _retrieve_logit_processors(
self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams
):
forced_decoder_ids = generation_config.forced_decoder_ids
if generation_config.return_timestamps is True:
last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None
if last_forced_decoder_ids == generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = forced_decoder_ids

begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1

def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = (
@@ -1243,7 +1415,7 @@ class WhisperGenerationMixin:
)
generation_config.begin_suppress_tokens = None

if no_speech_threshold is not None and not is_shortform:
if generation_config.no_speech_threshold is not None and not is_shortform:
no_speech_detector = WhisperNoSpeechDetection(
no_speech_token=generation_config.no_timestamps_token_id - 1,
begin_index=begin_index,
@@ -1256,11 +1428,12 @@ class WhisperGenerationMixin:

if is_shortform and generation_config.forced_decoder_ids is not None:
forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
# TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after
# It's important that the `forced_tokens_proc` processor is appended after
# the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
# which would lead to unexpected behavior
# The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
# initialize all of them as `decoder_input_ids`.
# TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
logits_processor = (
[forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
)
@@ -1310,6 +1483,7 @@ class WhisperGenerationMixin:
current_segments,
batch_idx_map,
do_condition_on_prev_tokens,
prompt_ids,
generation_config,
config,
device,
@@ -1328,19 +1502,27 @@ class WhisperGenerationMixin:
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text

bos_token_tensor = prev_start_of_text * one_tensor[0]
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None

prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
padding="left",
bos_token_tensor=bos_token_tensor,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)

kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
elif prompt_ids is not None:
prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)
else:
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)
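A minimal sketch of the left-padded conditioning prefix this branch builds (pad id and tokens are made up; `_pad_to_max_length` itself is not reproduced):

import torch

pad_token_id = 50257
prev_tokens = torch.tensor([[pad_token_id, 50361, 7, 8]])   # left-padded previous segment
decoder_input_ids = torch.cat([prev_tokens, torch.tensor([[50258, 50259]])], dim=-1)
decoder_attention_mask = decoder_input_ids != pad_token_id  # masks only the left padding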
File diff suppressed because one or more lines are too long
@@ -1451,6 +1451,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# Original model wasn't trained with timestamps and has incorrect generation config
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")

# the audio is 4 seconds long
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

out = pipe(
@@ -1460,11 +1461,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
"chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
},
)