[Whisper] Refactor forced_decoder_ids & prompt ids (#28687)

* up

* Fix more

* Correct more

* Fix more tests

* fix fast tests

* Fix more

* fix more

* push all files

* finish all

* make style

* Fix timestamp wrap

* make style

* make style

* up

* up

* up

* Fix lang detection behavior

* Fix lang detection behavior

* Add lang detection test

* Fix lang detection behavior

* make style

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* better error message

* make style tests

* add warning

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Author: Patrick von Platen
Date: 2024-01-31 14:02:07 +02:00 (committed by GitHub)
Commit: 65a926e82b (parent: f9f1f2ac5e)
3 changed files with 605 additions and 228 deletions

src/transformers/models/whisper/generation_whisper.py

@@ -16,7 +16,7 @@ import copy
import math
import warnings
import zlib
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -174,6 +174,8 @@ class WhisperGenerationMixin:
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
weight_length = None
if "beam_indices" in generate_outputs:
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
# since the beam search strategy chooses the most probable sequences at the end of the search.
@@ -195,7 +197,9 @@ class WhisperGenerationMixin:
dim=2,
)
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
# make sure timestamps are as long as weights
input_length = weight_length or cross_attentions[0].shape[2]
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
batch_size = timestamps.shape[0]
if num_frames is not None:
@@ -260,6 +264,7 @@ class WhisperGenerationMixin:
language: Optional[str] = None,
is_multilingual: Optional[bool] = None,
prompt_ids: Optional[torch.Tensor] = None,
prompt_condition_type: Optional[str] = None, # first-segment, all-segments
condition_on_prev_tokens: Optional[bool] = None,
temperature: Optional[Union[float, Tuple[float, ...]]] = None,
compression_ratio_threshold: Optional[float] = None,
@@ -333,6 +338,9 @@ class WhisperGenerationMixin:
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
prompt_condition_type (`str`, *optional*):
Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
condition_on_prev_tokens (`bool`, *optional*):
Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
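The docstring above introduces `prompt_ids` together with the new `prompt_condition_type` argument. A minimal long-form usage sketch may help; the checkpoint name, the synthetic waveform, and the prompt text below are placeholder assumptions, not values taken from this diff:

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# placeholder waveform: 40 s of silence standing in for real 16 kHz speech,
# long enough (> 30 s) to trigger long-form generation
audio = np.zeros(16000 * 40, dtype=np.float32)
inputs = processor(
    audio,
    sampling_rate=16000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# get_prompt_ids prefixes the prompt text with <|startofprev|>
prompt_ids = processor.get_prompt_ids("Mr. Quilter, Halloway", return_tensors="pt")

generated = model.generate(
    **inputs,
    prompt_ids=prompt_ids,
    prompt_condition_type="all-segments",  # condition every segment on the prompt
    condition_on_prev_tokens=True,         # required when using "all-segments"
    return_timestamps=True,                # long-form generation predicts timestamps
)
print(processor.batch_decode(generated, skip_special_tokens=True))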
@@ -474,7 +482,7 @@ class WhisperGenerationMixin:
# 2. set global generate variables
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
num_segment_frames = input_stride * self.config.max_source_positions
total_input_frames = self._retrieve_total_input_frames(
batch_size, total_input_frames = self._retrieve_total_input_frames(
input_features=input_features, input_stride=input_stride, kwargs=kwargs
)
is_shortform = total_input_frames <= num_segment_frames
@@ -505,15 +513,6 @@ class WhisperGenerationMixin:
self._set_language_and_task(
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
)
# pass self.config for backward compatibility
self._set_forced_decoder_ids(
task=task,
language=language,
prompt_ids=prompt_ids,
generation_config=generation_config,
config=self.config,
kwargs=kwargs,
)
self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
self._set_num_frames(
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
@@ -525,12 +524,31 @@ class WhisperGenerationMixin:
no_speech_threshold=no_speech_threshold,
condition_on_prev_tokens=condition_on_prev_tokens,
)
self._set_prompt_condition_type(
generation_config=generation_config,
prompt_condition_type=prompt_condition_type,
)
# 4. Retrieve logits processors
# pass self.config for backward compatibility
init_tokens = self._retrieve_init_tokens(
input_features,
generation_config=generation_config,
config=self.config,
num_segment_frames=num_segment_frames,
kwargs=kwargs,
)
# TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead
# This function should be removed in v4.39
self._check_decoder_input_ids(
prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
)
# 3. Retrieve logits processors
begin_index = len(init_tokens)
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
no_speech_threshold=no_speech_threshold,
begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform,
num_beams=kwargs.get("num_beams", 1),
)
@@ -540,6 +558,27 @@ class WhisperGenerationMixin:
if temperature is not None:
kwargs["temperature"] = temperature
decoder_input_ids = kwargs.pop("decoder_input_ids", None)
if decoder_input_ids is None:
one_tensor = torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
if prompt_ids is not None:
decoder_input_ids = torch.cat(
[prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
)
if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
max_new_tokens = kwargs.get("max_new_tokens", 0)
raise ValueError(
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less than {self.config.max_target_positions}."
)
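To make the length check above concrete, here is a small worked example with assumed token counts (Whisper checkpoints ship with `config.max_target_positions = 448`; the numbers are illustrative only):

prompt_len = 220       # tokens in `prompt_ids`
init_len = 4           # e.g. <|startoftranscript|>, <|en|>, <|transcribe|>, <|notimestamps|>
max_new_tokens = 256
combined = prompt_len + init_len + max_new_tokens  # 480 > 448 -> the ValueError above is raised
# trimming the prompt (e.g. prompt_len = 180) or lowering `max_new_tokens` keeps the sum within 448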
outputs = super().generate(
input_features,
generation_config=generation_config,
@@ -547,6 +586,7 @@ class WhisperGenerationMixin:
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
**kwargs,
)
@@ -573,11 +613,15 @@ class WhisperGenerationMixin:
max_frames, seek = self._retrieve_max_frames_and_seek(
batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
)
init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config)
# 6.2 Prepare running variables, list for generation
cur_bsz = batch_size
current_segments = [[] for _ in range(batch_size)]
current_segments = self._prepare_segments(
prompt_ids=prompt_ids,
batch_size=batch_size,
generation_config=generation_config,
)
batch_idx_map = list(range(batch_size))
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
@@ -617,6 +661,7 @@ class WhisperGenerationMixin:
current_segments=current_segments,
batch_idx_map=batch_idx_map,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
prompt_ids=prompt_ids,
generation_config=generation_config,
config=self.config,
device=segment_input.device,
@@ -682,11 +727,16 @@ class WhisperGenerationMixin:
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right")
final_segments = (
[x[1:] for x in current_segments]
if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
else current_segments
)
sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}
return {"sequences": sequences, "segments": final_segments}
return sequences
@@ -721,7 +771,8 @@ class WhisperGenerationMixin:
for fallback_idx, temperature in enumerate(temperatures):
generation_config.do_sample = temperature is not None and temperature > 0.0
generation_config.temperature = temperature
generation_config.temperature = temperature if generation_config.do_sample else 1.0
generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
seek_outputs = super().generate(
@@ -736,13 +787,13 @@ class WhisperGenerationMixin:
)
# post-process sequence tokens and outputs to be in list form
sequence_tokens, seek_outputs = self._postprocess_outputs(
seek_outputs, return_token_timestamps, generation_config
seek_sequences, seek_outputs = self._postprocess_outputs(
seek_outputs=seek_outputs,
decoder_input_ids=decoder_input_ids,
return_token_timestamps=return_token_timestamps,
generation_config=generation_config,
)
# remove all previously passed decoder input ids
seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :]
# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
@@ -777,8 +828,9 @@ class WhisperGenerationMixin:
seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
do_condition_on_prev_tokens[fallback_index_map[i]] = (
generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5
generation_config.condition_on_prev_tokens and is_low_temperature
)
if needs_fallback[i]:
@@ -804,30 +856,44 @@ class WhisperGenerationMixin:
return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config):
@staticmethod
def _prepare_segments(prompt_ids, batch_size, generation_config):
if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
else:
current_segments = [[] for _ in range(batch_size)]
return current_segments
def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
# remove all previously passed decoder input ids
if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :]
return seek_outputs, seek_outputs
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
if generation_config.return_dict_in_generate:
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
return values[batch_idx].cpu()
def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]
else:
sequence_tokens = seek_outputs
sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]
return sequence_tokens, seek_outputs
@@ -884,7 +950,7 @@ class WhisperGenerationMixin:
@staticmethod
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
if input_features is not None:
return input_features.shape[-1]
return input_features.shape[0], input_features.shape[-1]
if "encoder_outputs" in kwargs:
encoder_outputs_shape = (
@@ -892,7 +958,7 @@ class WhisperGenerationMixin:
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
return encoder_outputs_shape[1] * input_stride
return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
@@ -950,34 +1016,24 @@ class WhisperGenerationMixin:
@staticmethod
def _set_return_timestamps(return_timestamps, is_shortform, generation_config):
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = True
elif not is_shortform:
if not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
return_timestamps = True
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
@staticmethod
def _set_language_and_task(language, task, is_multilingual, generation_config):
@@ -1016,94 +1072,221 @@ class WhisperGenerationMixin:
)
generation_config.task = task
@staticmethod
def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs):
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
def _retrieve_init_tokens(self, input_features, generation_config, config, num_segment_frames, kwargs):
def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
"""short function to replace num with a itr in lst"""
found = any(i in lst for i in itr)
if found:
lst = [num if i in itr else i for i in lst]
else:
lst.append(num)
return lst
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)
if kwargs.get("forced_decoder_ids", None) is not None:
forced_decoder_ids = kwargs["forced_decoder_ids"]
elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
forced_decoder_ids = generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
elif generation_config.language in TO_LANGUAGE_CODE.values():
language_token = f"<|{generation_config.language}|>"
else:
is_language_code = len(generation_config.language) == 2
raise ValueError(
f"Unsupported language: {generation_config.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
elif hasattr(generation_config, "task_to_id"):
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids
if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
else:
forced_decoder_ids = None
# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
if kwargs["max_new_tokens"] >= config.max_target_positions:
raise ValueError(
f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
f"`max_target_positions` of the Whisper model: {config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {config.max_target_positions}."
)
# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None:
logger.info(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
)
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.info(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None
init_tokens = [generation_config.decoder_start_token_id]
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
init_tokens += [forced_decoder_ids[0][1]]
forced_decoder_ids = forced_decoder_ids[1:]
i += 1
# TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
if len(forced_decoder_ids) > 0:
warnings.warn(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.",
FutureWarning,
)
# TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39
generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
if language is not None:
if language in generation_config.lang_to_id.keys():
language_token = language
elif language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
elif language in TO_LANGUAGE_CODE.values():
language_token = f"<|{language}|>"
else:
is_language_code = len(language) == 2
raise ValueError(
f"Unsupported language: {language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
)
lang_id = generation_config.lang_to_id[language_token]
# if language is defined it'll overwrite language ids that might have already been defined via the generation_config
replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values())
elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
# language is not defined or intentionally set to `None` to trigger language detection
lang_ids = self.detect_language(
input_features=input_features,
encoder_outputs=kwargs.get("encoder_outputs", None),
generation_config=generation_config,
num_segment_frames=num_segment_frames,
)
if torch.unique(lang_ids).shape[0] > 1:
raise ValueError(
"Multiple languages detected when trying to predict the most likely target language for transcription. It is currently not supported to transcribe to different languages in a single batch. Please make sure to either force a single language by passing `language='...'` or make sure all input audio is of the same language."
)
lang_id = lang_ids[0].item()
# append or replace lang_id to init_tokens
if len(init_tokens) > 1:
init_tokens[1] = lang_id
else:
init_tokens.append(lang_id)
if task is not None:
if task in TASK_IDS:
init_tokens.append(generation_config.task_to_id[generation_config.task])
task_id = generation_config.task_to_id[generation_config.task]
# if task is defined it'll overwrite task ids that might have already been defined via the generation_config
replace_or_add(init_tokens, task_id, generation_config.task_to_id.values())
else:
raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
elif language is not None and hasattr(generation_config, "task_to_id"):
# if language is defined, but no task id is in `init_tokens`, default to transcribe
if not any(i in init_tokens for i in generation_config.task_to_id.values()):
init_tokens.append(generation_config.task_to_id["transcribe"])
if (
not generation_config.return_timestamps
and hasattr(generation_config, "no_timestamps_token_id")
and init_tokens[-1] != generation_config.no_timestamps_token_id
):
init_tokens.append(generation_config.no_timestamps_token_id)
elif generation_config.return_timestamps and init_tokens[-1] == generation_config.no_timestamps_token_id:
logger.info(
"<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
)
init_tokens = init_tokens[:-1]
# let's make sure we don't pass `None` tokens as prompt tokens
init_tokens = [t for t in init_tokens if t is not None]
return init_tokens
def detect_language(
self,
input_features: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
generation_config: Optional[GenerationConfig] = None,
num_segment_frames: int = 3000,
) -> torch.Tensor:
"""
Detects language from log-mel input features or encoder_outputs
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
num_segment_frames (`int`, defaults to 3000):
The number of log-mel frames the model expects
Return:
A `torch.LongTensor` representing the detected language ids.
"""
if input_features is None and encoder_outputs is None:
raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
elif input_features is not None and encoder_outputs is not None:
raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
elif input_features is not None:
inputs = {"input_features": input_features[:, :, :num_segment_frames]}
batch_size = input_features.shape[0]
elif encoder_outputs is not None:
inputs = {"encoder_outputs": encoder_outputs}
batch_size = (
encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
)
generation_config = generation_config or self.generation_config
decoder_input_ids = (
torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
* generation_config.decoder_start_token_id
)
with torch.no_grad():
logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
non_lang_mask[list(generation_config.lang_to_id.values())] = False
logits[:, non_lang_mask] = -np.inf
lang_ids = logits.argmax(-1)
return lang_ids
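A short usage sketch for the new standalone `detect_language` helper defined above; the checkpoint name and the zero waveform are placeholders rather than values taken from this diff:

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# placeholder waveform; in practice this would be real 16 kHz speech
audio = np.zeros(16000 * 5, dtype=np.float32)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

lang_ids = model.detect_language(input_features=input_features)
# map the predicted token ids back to language tokens such as "<|en|>"
print(processor.tokenizer.convert_ids_to_tokens(lang_ids.tolist()))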
@staticmethod
def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None)
if prompt_ids is not None and decoder_input_ids is not None:
raise ValueError(
f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it."
)
elif decoder_input_ids is not None and not is_shortform:
raise ValueError(
f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
)
elif decoder_input_ids is not None and is_shortform:
warnings.warn(
f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
FutureWarning,
)
forced_decoder_ids = [
*text_prompt_ids,
generation_config.decoder_start_token_id,
*[token for _, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
@staticmethod
def _set_token_ids(generation_config, config, kwargs):
@@ -1162,6 +1345,25 @@ class WhisperGenerationMixin:
else getattr(generation_config, "condition_on_prev_tokens", None)
)
@staticmethod
def _set_prompt_condition_type(generation_config, prompt_condition_type):
allowed_cond_types = ["first-segment", "all-segments"]
# default to "first-segment"
prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
if prompt_condition_type not in allowed_cond_types:
raise ValueError(
f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
)
if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
raise ValueError(
"Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
)
generation_config.prompt_condition_type = prompt_condition_type
@staticmethod
def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
condition_on_prev_tokens = (
@@ -1175,7 +1377,7 @@ class WhisperGenerationMixin:
def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames):
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
"When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
@@ -1186,37 +1388,7 @@ class WhisperGenerationMixin:
return max_frames, seek
@staticmethod
def _retrieve_init_tokens_from_forced_decoder_ids(generation_config):
init_tokens = [generation_config.decoder_start_token_id]
forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
init_tokens += [forced_decoder_ids[0][1]]
forced_decoder_ids = forced_decoder_ids[1:]
i += 1
forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
generation_config.forced_decoder_ids = forced_decoder_ids
return init_tokens
def _retrieve_logit_processors(
self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams
):
forced_decoder_ids = generation_config.forced_decoder_ids
if generation_config.return_timestamps is True:
last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None
if last_forced_decoder_ids == generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = forced_decoder_ids
begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = (
@@ -1243,7 +1415,7 @@ class WhisperGenerationMixin:
)
generation_config.begin_suppress_tokens = None
if no_speech_threshold is not None and not is_shortform:
if generation_config.no_speech_threshold is not None and not is_shortform:
no_speech_detector = WhisperNoSpeechDetection(
no_speech_token=generation_config.no_timestamps_token_id - 1,
begin_index=begin_index,
@@ -1256,11 +1428,12 @@ class WhisperGenerationMixin:
if is_shortform and generation_config.forced_decoder_ids is not None:
forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
# TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after
# It's important that the `forced_tokens_proc` processor is appended after
# the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
# which would lead to unexpected behavior
# The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
# initialize all of them as `decoder_input_ids`.
# TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
logits_processor = (
[forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
)
@@ -1310,6 +1483,7 @@ class WhisperGenerationMixin:
current_segments,
batch_idx_map,
do_condition_on_prev_tokens,
prompt_ids,
generation_config,
config,
device,
@@ -1328,19 +1502,27 @@ class WhisperGenerationMixin:
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text
bos_token_tensor = prev_start_of_text * one_tensor[0]
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
padding="left",
bos_token_tensor=bos_token_tensor,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
elif prompt_ids is not None:
prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)
else:
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)

File diff suppressed because one or more lines are too long


@@ -1451,6 +1451,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# Original model wasn't trained with timestamps and has incorrect generation config
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
# the audio is 4 seconds long
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
out = pipe(
@@ -1460,11 +1461,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
"chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
},
)