mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
[Whisper] Refactor forced_decoder_ids & prompt ids (#28687)
* up
* Fix more
* Correct more
* Fix more tests
* fix fast tests
* Fix more
* fix more
* push all files
* finish all
* make style
* Fix timestamp wrap
* make style
* make style
* up
* up
* up
* Fix lang detection behavior
* Fix lang detection behavior
* Add lang detection test
* Fix lang detection behavior
* make style
* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* better error message
* make style tests
* add warning

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent f9f1f2ac5e
commit 65a926e82b
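For orientation, a minimal usage sketch (not part of the commit) of the refactored API in `src/transformers/models/whisper/generation_whisper.py`: forced decoder ids are now materialized as decoder prompt tokens rather than a logits processor, and leaving `language` unset triggers language detection instead of silently translating to English. The checkpoint name and the `audio` array below are assumptions.

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# `audio` is assumed to be a 16 kHz mono waveform (e.g. loaded with soundfile)
input_features = processor(audio, sampling_rate=16_000, return_tensors="pt").input_features

# hypothetical prompt with vocabulary the model should prefer
prompt_ids = torch.tensor(processor.get_prompt_ids("Sanchit Gandhi, Whisper"))

generated_ids = model.generate(
    input_features,
    task="transcribe",
    language=None,  # None -> detect the language instead of defaulting to translation
    prompt_ids=prompt_ids,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))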
@@ -16,7 +16,7 @@ import copy
import math
import warnings
import zlib
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -174,6 +174,8 @@ class WhisperGenerationMixin:
        weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
        weights = weights.permute([1, 0, 2, 3])

+        weight_length = None
+
        if "beam_indices" in generate_outputs:
            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
            # since the beam search strategy chooses the most probable sequences at the end of the search.
@@ -195,7 +197,9 @@ class WhisperGenerationMixin:
            dim=2,
        )

-        timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
+        # make sure timestamps are as long as weights
+        input_length = weight_length or cross_attentions[0].shape[2]
+        timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
        batch_size = timestamps.shape[0]

        if num_frames is not None:
@@ -260,6 +264,7 @@ class WhisperGenerationMixin:
        language: Optional[str] = None,
        is_multilingual: Optional[bool] = None,
        prompt_ids: Optional[torch.Tensor] = None,
+        prompt_condition_type: Optional[str] = None,  # first-segment, all-segments
        condition_on_prev_tokens: Optional[bool] = None,
        temperature: Optional[Union[float, Tuple[float, ...]]] = None,
        compression_ratio_threshold: Optional[float] = None,
@@ -333,6 +338,9 @@ class WhisperGenerationMixin:
                provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
                transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
                correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
+            prompt_condition_type (`str`, *optional*):
+                Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
+                Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
            condition_on_prev_tokens (`bool`, *optional*):
                Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
                As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
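A hedged illustration of the two condition types documented above; the long-form features, attention mask and prompt text are placeholders, not values from the commit.

prompt_ids = processor.get_prompt_ids("OpenAI, Whisper", return_tensors="pt")  # hypothetical vocabulary hint

# 'first-segment' (the default): only the first 30 s window sees the prompt
out_first = model.generate(long_input_features, attention_mask=attention_mask, prompt_ids=prompt_ids)

# 'all-segments': every window is conditioned on the prompt; requires condition_on_prev_tokens=True
out_all = model.generate(
    long_input_features,
    attention_mask=attention_mask,
    prompt_ids=prompt_ids,
    prompt_condition_type="all-segments",
    condition_on_prev_tokens=True,
)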
@@ -474,7 +482,7 @@ class WhisperGenerationMixin:
        # 2. set global generate variables
        input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
        num_segment_frames = input_stride * self.config.max_source_positions
-        total_input_frames = self._retrieve_total_input_frames(
+        batch_size, total_input_frames = self._retrieve_total_input_frames(
            input_features=input_features, input_stride=input_stride, kwargs=kwargs
        )
        is_shortform = total_input_frames <= num_segment_frames
@@ -505,15 +513,6 @@ class WhisperGenerationMixin:
        self._set_language_and_task(
            language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
        )
-        # pass self.config for backward compatibility
-        self._set_forced_decoder_ids(
-            task=task,
-            language=language,
-            prompt_ids=prompt_ids,
-            generation_config=generation_config,
-            config=self.config,
-            kwargs=kwargs,
-        )
        self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
        self._set_num_frames(
            return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
@@ -525,12 +524,31 @@ class WhisperGenerationMixin:
            no_speech_threshold=no_speech_threshold,
            condition_on_prev_tokens=condition_on_prev_tokens,
        )
+        self._set_prompt_condition_type(
+            generation_config=generation_config,
+            prompt_condition_type=prompt_condition_type,
+        )

-        # 4. Retrieve logits processors
+        # pass self.config for backward compatibility
+        init_tokens = self._retrieve_init_tokens(
+            input_features,
+            generation_config=generation_config,
+            config=self.config,
+            num_segment_frames=num_segment_frames,
+            kwargs=kwargs,
+        )
+        # TODO(Sanchit) - passing `decoder_input_ids` is deprecated. One should use `prompt_ids` instead
+        # This function should be be removed in v4.39
+        self._check_decoder_input_ids(
+            prompt_ids=prompt_ids, init_tokens=init_tokens, is_shortform=is_shortform, kwargs=kwargs
+        )
+
+        # 3. Retrieve logits processors
+        begin_index = len(init_tokens)
        logits_processor = self._retrieve_logit_processors(
            generation_config=generation_config,
            logits_processor=logits_processor,
-            no_speech_threshold=no_speech_threshold,
+            begin_index=begin_index,  # begin index is index of first generated decoder token
            is_shortform=is_shortform,
            num_beams=kwargs.get("num_beams", 1),
        )
@@ -540,6 +558,27 @@ class WhisperGenerationMixin:
            if temperature is not None:
                kwargs["temperature"] = temperature

+            decoder_input_ids = kwargs.pop("decoder_input_ids", None)
+            if decoder_input_ids is None:
+                one_tensor = torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
+                decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
+
+            if prompt_ids is not None:
+                decoder_input_ids = torch.cat(
+                    [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
+                )
+
+            if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
+                max_new_tokens = kwargs.get("max_new_tokens", 0)
+                raise ValueError(
+                    f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
+                    f"is {max_new_tokens}. Thus, the combined length of "
+                    f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
+                    f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
+                    "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
+                    f"so that their combined length is less than {self.config.max_target_positions}."
+                )
+
            outputs = super().generate(
                input_features,
                generation_config=generation_config,
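A self-contained sketch of what the block above builds, using made-up token ids in place of the checkpoint's real special tokens:

import torch

init_tokens = [50258, 50259, 50359, 50363]      # e.g. <|startoftranscript|>, language, task, <|notimestamps|> (illustrative ids)
prompt_ids = torch.tensor([50361, 1012, 2044])  # e.g. <|startofprev|> followed by prompt tokens (illustrative ids)
batch_size = 2

one_tensor = torch.ones((batch_size, 1), dtype=torch.long)
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)                          # shape (2, 4)
decoder_input_ids = torch.cat([prompt_ids[None].repeat(batch_size, 1), decoder_input_ids], dim=-1)    # shape (2, 7)
# these ids are then handed to `super().generate(..., decoder_input_ids=decoder_input_ids)`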
@@ -547,6 +586,7 @@ class WhisperGenerationMixin:
                stopping_criteria=stopping_criteria,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                synced_gpus=synced_gpus,
+                decoder_input_ids=decoder_input_ids,
                **kwargs,
            )

@@ -573,11 +613,15 @@ class WhisperGenerationMixin:
        max_frames, seek = self._retrieve_max_frames_and_seek(
            batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
        )
-        init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config)

        # 6.2 Preppare running variables, list for generation
        cur_bsz = batch_size
-        current_segments = [[] for _ in range(batch_size)]
+        current_segments = self._prepare_segments(
+            prompt_ids=prompt_ids,
+            batch_size=batch_size,
+            generation_config=generation_config,
+        )
+
        batch_idx_map = list(range(batch_size))
        do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]

@@ -617,6 +661,7 @@ class WhisperGenerationMixin:
                current_segments=current_segments,
                batch_idx_map=batch_idx_map,
                do_condition_on_prev_tokens=do_condition_on_prev_tokens,
+                prompt_ids=prompt_ids,
                generation_config=generation_config,
                config=self.config,
                device=segment_input.device,
@@ -682,11 +727,16 @@ class WhisperGenerationMixin:

        # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
        # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
-        sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right")
+        final_segments = (
+            [x[1:] for x in current_segments]
+            if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
+            else current_segments
+        )
+        sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")

        # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
        if return_segments:
-            return {"sequences": sequences, "segments": current_segments}
+            return {"sequences": sequences, "segments": final_segments}

        return sequences

@@ -721,7 +771,8 @@ class WhisperGenerationMixin:

        for fallback_idx, temperature in enumerate(temperatures):
            generation_config.do_sample = temperature is not None and temperature > 0.0
-            generation_config.temperature = temperature
+            generation_config.temperature = temperature if generation_config.do_sample else 1.0
            generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1

            seek_outputs = super().generate(
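A hedged sketch of the fallback loop from the caller's side: temperatures are tried in order, and sampling is only enabled for values greater than 0.0. All inputs and thresholds below are placeholders.

segments = model.generate(
    long_input_features,
    attention_mask=attention_mask,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    condition_on_prev_tokens=True,
    return_timestamps=True,
)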
@@ -736,13 +787,13 @@ class WhisperGenerationMixin:
            )

            # post-process sequence tokens and outputs to be in list form
-            sequence_tokens, seek_outputs = self._postprocess_outputs(
-                seek_outputs, return_token_timestamps, generation_config
+            seek_sequences, seek_outputs = self._postprocess_outputs(
+                seek_outputs=seek_outputs,
+                decoder_input_ids=decoder_input_ids,
+                return_token_timestamps=return_token_timestamps,
+                generation_config=generation_config,
            )

-            # remove all previously passed decoder input ids
-            seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :]
-
            # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
            # Loop over each decoded audio individually as each decoding can be of a different length
            new_fallback_index_map = []
@@ -777,8 +828,9 @@ class WhisperGenerationMixin:

                seek_sequence_list[fallback_index_map[i]] = seek_sequence
                seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
+                is_low_temperature = temperature is None or temperature < 0.5
                do_condition_on_prev_tokens[fallback_index_map[i]] = (
-                    generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5
+                    generation_config.condition_on_prev_tokens and is_low_temperature
                )

                if needs_fallback[i]:
@@ -804,30 +856,44 @@ class WhisperGenerationMixin:

        return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens

-    def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config):
+    @staticmethod
+    def _prepare_segments(prompt_ids, batch_size, generation_config):
+        if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
+            prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
+            prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
+            current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
+        else:
+            current_segments = [[] for _ in range(batch_size)]
+
+        return current_segments
+
+    def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
+        # remove all previously passed decoder input ids
+        if isinstance(seek_outputs, torch.Tensor):
+            seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :]
+            return seek_outputs, seek_outputs
+
        if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
            num_frames = getattr(generation_config, "num_frames", None)
            seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
            )

-        if generation_config.return_dict_in_generate:
+        seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
+
        def split_by_batch_index(values, key, batch_idx):
            if key == "scores":
                return [v[batch_idx].cpu() for v in values]
            if key == "past_key_values":
                # we don't save `past_key_values` as this is too costly
                return None
            return values[batch_idx].cpu()

        sequence_tokens = seek_outputs["sequences"]
        seek_outputs = [
            {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
            for i in range(sequence_tokens.shape[0])
        ]
-        else:
-            sequence_tokens = seek_outputs

        return sequence_tokens, seek_outputs

@@ -884,7 +950,7 @@ class WhisperGenerationMixin:
    @staticmethod
    def _retrieve_total_input_frames(input_features, input_stride, kwargs):
        if input_features is not None:
-            return input_features.shape[-1]
+            return input_features.shape[0], input_features.shape[-1]

        if "encoder_outputs" in kwargs:
            encoder_outputs_shape = (
@@ -892,7 +958,7 @@ class WhisperGenerationMixin:
                if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
                else kwargs["encoder_outputs"].shape
            )
-            return encoder_outputs_shape[1] * input_stride
+            return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride

        raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")

@@ -950,34 +1016,24 @@ class WhisperGenerationMixin:

    @staticmethod
    def _set_return_timestamps(return_timestamps, is_shortform, generation_config):
-        if return_timestamps is True:
-            if not hasattr(generation_config, "no_timestamps_token_id"):
-                raise ValueError(
-                    "You are trying to return timestamps, but the generation config is not properly set. "
-                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
-                    "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
-                )
-            generation_config.return_timestamps = True
-        elif not is_shortform:
+        if not is_shortform:
            if return_timestamps is False:
                raise ValueError(
                    "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
                    "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
                )

-            if not hasattr(generation_config, "no_timestamps_token_id"):
-                raise ValueError(
-                    "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
-                    "requires the generation config to have `no_timestamps_token_id` correctly. "
-                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
-                    "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
-                    "or make sure to pass no more than 3000 mel input features."
-                )
-
            logger.info("Setting `return_timestamps=True` for long-form generation.")
-            generation_config.return_timestamps = True
-        else:
-            generation_config.return_timestamps = False
+            return_timestamps = True
+
+        if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
+            raise ValueError(
+                "You are trying to return timestamps, but the generation config is not properly set. "
+                "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
+                "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
+            )
+
+        generation_config.return_timestamps = return_timestamps

    @staticmethod
    def _set_language_and_task(language, task, is_multilingual, generation_config):
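Illustration of the consolidated check above, with placeholder inputs: for long-form audio the flag is switched on automatically (with an info log) unless it was explicitly set to False, which raises.

# more than 30 s of audio: timestamps are enabled for you
out = model.generate(long_input_features, attention_mask=attention_mask)

# explicitly disabling them for long-form input raises a ValueError
out = model.generate(long_input_features, attention_mask=attention_mask, return_timestamps=False)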
@@ -1016,94 +1072,221 @@ class WhisperGenerationMixin:
        )
        generation_config.task = task

-    @staticmethod
-    def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs):
-        forced_decoder_ids = None
-        # Legacy code for backward compatibility
-        if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
-            forced_decoder_ids = config.forced_decoder_ids
-        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
-            forced_decoder_ids = generation_config.forced_decoder_ids
-        else:
-            forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
-
-        if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
-            forced_decoder_ids = []
-            if hasattr(generation_config, "language"):
-                if generation_config.language in generation_config.lang_to_id.keys():
-                    language_token = generation_config.language
-                elif generation_config.language in TO_LANGUAGE_CODE.keys():
-                    language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
-                elif generation_config.language in TO_LANGUAGE_CODE.values():
-                    language_token = f"<|{generation_config.language}|>"
-                else:
-                    is_language_code = len(generation_config.language) == 2
-                    raise ValueError(
-                        f"Unsupported language: {generation_config.language}. Language should be one of:"
-                        f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
-                    )
-                if language_token not in generation_config.lang_to_id:
-                    raise ValueError(
-                        f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
-                        "(You should just add it to the generation config)"
-                    )
-                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
-            else:
-                forced_decoder_ids.append((1, None))  # automatically detect the language
-
-            if hasattr(generation_config, "task"):
-                if generation_config.task in TASK_IDS:
-                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
-                else:
-                    raise ValueError(
-                        f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
-                    )
-            elif hasattr(generation_config, "task_to_id"):
-                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
-            if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
-                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
-                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
-
-        if forced_decoder_ids is not None:
-            generation_config.forced_decoder_ids = forced_decoder_ids
-
-        if prompt_ids is not None:
-            if kwargs.get("decoder_start_token_id") is not None:
-                raise ValueError(
-                    "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
-                )
-            prompt_ids = prompt_ids.tolist()
-            decoder_start_token_id, *text_prompt_ids = prompt_ids
-            # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
-            # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
-            text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :]
-            # Set the decoder_start_token_id to <|startofprev|>
-            kwargs.update({"decoder_start_token_id": decoder_start_token_id})
-
-            # If the user passes `max_new_tokens`, increase its number to account for the prompt
-            if kwargs.get("max_new_tokens", None) is not None:
-                kwargs["max_new_tokens"] += len(text_prompt_ids)
-                if kwargs["max_new_tokens"] >= config.max_target_positions:
-                    raise ValueError(
-                        f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
-                        f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
-                        f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
-                        f"`max_target_positions` of the Whisper model: {config.max_target_positions}. "
-                        "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
-                        f"so that their combined length is less that {config.max_target_positions}."
-                    )
-
-            # Reformat the forced_decoder_ids to incorporate the prompt
-            non_prompt_forced_decoder_ids = (
-                kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
-            )
-            forced_decoder_ids = [
-                *text_prompt_ids,
-                generation_config.decoder_start_token_id,
-                *[token for _, token in non_prompt_forced_decoder_ids],
-            ]
-            forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
-            generation_config.forced_decoder_ids = forced_decoder_ids
+    def _retrieve_init_tokens(self, input_features, generation_config, config, num_segment_frames, kwargs):
+        def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
+            """short function to replace num with a itr in lst"""
+            found = any(i in lst for i in itr)
+            if found:
+                lst = [num if i in itr else i for i in lst]
+            else:
+                lst.append(num)
+            return lst
+
+        task = getattr(generation_config, "task", None)
+        language = getattr(generation_config, "language", None)
+
+        if kwargs.get("forced_decoder_ids", None) is not None:
+            forced_decoder_ids = kwargs["forced_decoder_ids"]
+        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
+            forced_decoder_ids = generation_config.forced_decoder_ids
+
+            if language is None and task is None and forced_decoder_ids[0][1] is None:
+                logger.warning_once(
+                    "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
+                    "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
+                )
+        elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
+            forced_decoder_ids = config.forced_decoder_ids
+        else:
+            forced_decoder_ids = None
+
+        if forced_decoder_ids is not None and task is not None:
+            logger.info(
+                f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
+            )
+            forced_decoder_ids = None
+        elif forced_decoder_ids is not None and language is not None:
+            logger.info(
+                f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
+            )
+            forced_decoder_ids = None
+
+        init_tokens = [generation_config.decoder_start_token_id]
+        if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
+            i = 1
+            while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
+                init_tokens += [forced_decoder_ids[0][1]]
+                forced_decoder_ids = forced_decoder_ids[1:]
+                i += 1
+
+            # TODO(Sanchit): Let's make sure we don't allow incorrectly / weirdly formatted `forced_decoder_ids` after transformers v4.39
+            if len(forced_decoder_ids) > 0:
+                warnings.warn(
+                    f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}. `forced_decoder_ids` will be passed as a logit processor, but note that this functionality has been deprecated and will throw an error in v4.39.",
+                    FutureWarning,
+                )
+
+            # TODO(Sanchit): set generation_config.forced_decoder_ids to None for v4.39
+            generation_config.forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
+
+        is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
+        if language is not None:
+            if language in generation_config.lang_to_id.keys():
+                language_token = language
+            elif language in TO_LANGUAGE_CODE.keys():
+                language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
+            elif language in TO_LANGUAGE_CODE.values():
+                language_token = f"<|{language}|>"
+            else:
+                is_language_code = len(language) == 2
+                raise ValueError(
+                    f"Unsupported language: {language}. Language should be one of:"
+                    f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
+                )
+            if language_token not in generation_config.lang_to_id:
+                raise ValueError(
+                    f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
+                    "(You should just add it to the generation config)"
+                )
+
+            lang_id = generation_config.lang_to_id[language_token]
+
+            # if language is defined it'll overwrite language ids that might have already been defined via the generation_config
+            replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values())
+        elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
+            # language is not defined or intentially set to `None` to trigger language detection
+            lang_ids = self.detect_language(
+                input_features=input_features,
+                encoder_outputs=kwargs.get("encoder_outputs", None),
+                generation_config=generation_config,
+                num_segment_frames=num_segment_frames,
+            )
+
+            if torch.unique(lang_ids).shape[0] > 1:
+                raise ValueError(
+                    "Multiple languages detected when trying to predict the most likely target language for transcription. It is currently not supported to transcribe to different languages in a single batch. Please make sure to either force a single language by passing `language='...'` or make sure all input audio is of the same language."
+                )
+
+            lang_id = lang_ids[0].item()
+
+            # append or replace lang_id to init_tokens
+            if len(init_tokens) > 1:
+                init_tokens[1] = lang_id
+            else:
+                init_tokens.append(lang_id)
+
+        if task is not None:
+            if task in TASK_IDS:
+                init_tokens.append(generation_config.task_to_id[generation_config.task])
+                task_id = generation_config.task_to_id[generation_config.task]
+
+                # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
+                replace_or_add(init_tokens, task_id, generation_config.task_to_id.values())
+            else:
+                raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
+        elif language is not None and hasattr(generation_config, "task_to_id"):
+            # if language is defined, but no task id is in `init_tokens`, default to transcribe
+            if not any(i in init_tokens for i in generation_config.task_to_id.values()):
+                init_tokens.append(generation_config.task_to_id["transcribe"])
+
+        if (
+            not generation_config.return_timestamps
+            and hasattr(generation_config, "no_timestamps_token_id")
+            and init_tokens[-1] != generation_config.no_timestamps_token_id
+        ):
+            init_tokens.append(generation_config.no_timestamps_token_id)
+        elif generation_config.return_timestamps and init_tokens[-1] == generation_config.no_timestamps_token_id:
+            logger.info(
+                "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
+            )
+            init_tokens = init_tokens[:-1]
+
+        # let's make sure we don't pass `None` tokens as prompt tokens
+        init_tokens = [t for t in init_tokens if t is not None]
+
+        return init_tokens
+
+    def detect_language(
+        self,
+        input_features: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        num_segment_frames: int = 3000,
+    ) -> torch.Tensor:
+        """
+        Detects language from log-mel input features or encoder_outputs
+
+        Parameters:
+            input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
+                Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
+                loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
+                the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
+                [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
+                tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
+            encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+                Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+                `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+                hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            num_segment_frames (`int`, defaults to 3000):
+                The number of log-mel frames the model expects
+
+        Return:
+            A `torch.LongTensor` representing the detected language ids.
+        """
+        if input_features is None and encoder_outputs is None:
+            raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
+        elif input_features is not None and encoder_outputs is not None:
+            raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
+        elif input_features is not None:
+            inputs = {"input_features": input_features[:, :, :num_segment_frames]}
+            batch_size = input_features.shape[0]
+        elif encoder_outputs is not None:
+            inputs = {"encoder_outputs": encoder_outputs}
+            batch_size = (
+                encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
+            )
+
+        generation_config = generation_config or self.generation_config
+        decoder_input_ids = (
+            torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
+            * generation_config.decoder_start_token_id
+        )
+
+        with torch.no_grad():
+            logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
+
+        non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
+        non_lang_mask[list(generation_config.lang_to_id.values())] = False
+
+        logits[:, non_lang_mask] = -np.inf
+
+        lang_ids = logits.argmax(-1)
+
+        return lang_ids
+
+    @staticmethod
+    def _check_decoder_input_ids(prompt_ids, init_tokens, is_shortform, kwargs):
+        decoder_input_ids = kwargs.get("decoder_input_ids", None)
+        if prompt_ids is not None and decoder_input_ids is not None:
+            raise ValueError(
+                f"Cannot pass both `prompt_ids`: {prompt_ids} and `decoder_input_ids`: {decoder_input_ids}. Passing `decoder_input_ids` is deprecated, consider not passing it."
+            )
+        elif decoder_input_ids is not None and not is_shortform:
+            raise ValueError(
+                f"Cannot pass both `decoder_input_ids`: {decoder_input_ids} for long-form generation. Consider passing `prompt_ids` instead."
+            )
+        elif decoder_input_ids is not None and is_shortform:
+            warnings.warn(
+                f"You have provided `decoder_input_ids` which will overwrite the `init_tokens` {init_tokens}. This might lead to unexpected behavior. Passing `decoder_input_ids` is deprecated and will be removed in v4.39. Consider passing `prompt_ids` instead.",
+                FutureWarning,
+            )

    @staticmethod
    def _set_token_ids(generation_config, config, kwargs):
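A small sketch of the new public `detect_language` helper added above; the model, features and printed result are assumptions carried over from the earlier sketch.

lang_token_ids = model.detect_language(input_features=input_features)

id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
print([id_to_lang[int(i)] for i in lang_token_ids])  # e.g. ['<|hi|>']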
@@ -1162,6 +1345,25 @@ class WhisperGenerationMixin:
            else getattr(generation_config, "condition_on_prev_tokens", None)
        )

+    @staticmethod
+    def _set_prompt_condition_type(generation_config, prompt_condition_type):
+        allowed_cond_types = ["first-segment", "all-segments"]
+
+        # default to "first-segment"
+        prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
+
+        if prompt_condition_type not in allowed_cond_types:
+            raise ValueError(
+                f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
+            )
+
+        if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
+            raise ValueError(
+                "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
+            )
+
+        generation_config.prompt_condition_type = prompt_condition_type
+
    @staticmethod
    def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
        condition_on_prev_tokens = (
@@ -1175,7 +1377,7 @@ class WhisperGenerationMixin:
    def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames):
        if batch_size > 1 and attention_mask is None:
            raise ValueError(
-                "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
+                "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
            )
        elif batch_size > 1:
            max_frames = attention_mask.sum(-1).cpu().to(torch.long)
@@ -1186,37 +1388,7 @@ class WhisperGenerationMixin:

        return max_frames, seek

-    @staticmethod
-    def _retrieve_init_tokens_from_forced_decoder_ids(generation_config):
-        init_tokens = [generation_config.decoder_start_token_id]
-        forced_decoder_ids = generation_config.forced_decoder_ids
-        if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
-            i = 1
-            while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
-                init_tokens += [forced_decoder_ids[0][1]]
-                forced_decoder_ids = forced_decoder_ids[1:]
-                i += 1
-
-            forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None
-            generation_config.forced_decoder_ids = forced_decoder_ids
-
-        return init_tokens
-
-    def _retrieve_logit_processors(
-        self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams
-    ):
-        forced_decoder_ids = generation_config.forced_decoder_ids
-        if generation_config.return_timestamps is True:
-            last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None
-            if last_forced_decoder_ids == generation_config.no_timestamps_token_id:
-                # remove no_timestamp to be forcefully generated if we want to return timestamps
-                # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
-                forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None
-                # Make sure that if list is empty we set it to None
-                generation_config.forced_decoder_ids = forced_decoder_ids
-
-        begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1
-
+    def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams):
        if generation_config.return_timestamps is True:
            timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
            logits_processor = (
@@ -1243,7 +1415,7 @@ class WhisperGenerationMixin:
            )
            generation_config.begin_suppress_tokens = None

-        if no_speech_threshold is not None and not is_shortform:
+        if generation_config.no_speech_threshold is not None and not is_shortform:
            no_speech_detector = WhisperNoSpeechDetection(
                no_speech_token=generation_config.no_timestamps_token_id - 1,
                begin_index=begin_index,
@@ -1256,11 +1428,12 @@ class WhisperGenerationMixin:

        if is_shortform and generation_config.forced_decoder_ids is not None:
            forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)
-            # TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after
+            # It's important that the `forced_tokens_proc` processor is appended after
            # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf
            # which would lead to unexpected behavior
            # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead
            # initialize all of them as `decoder_input_ids`.
+            # TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore.
            logits_processor = (
                [forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc]
            )
@@ -1310,6 +1483,7 @@ class WhisperGenerationMixin:
        current_segments,
        batch_idx_map,
        do_condition_on_prev_tokens,
+        prompt_ids,
        generation_config,
        config,
        device,
@@ -1328,19 +1502,27 @@ class WhisperGenerationMixin:
        if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
            # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
            active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
-            prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text

-            bos_token_tensor = prev_start_of_text * one_tensor[0]
+            if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
+                prev_ids = prompt_ids
+            else:
+                prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
+
            prev_tokens = _pad_to_max_length(
                active_segments,
                generation_config.pad_token_id,
                padding="left",
-                bos_token_tensor=bos_token_tensor,
+                bos_token_tensor=prev_ids,
                cut_off_length=cut_off_length,
            )
            decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)

            kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
+        elif prompt_ids is not None:
+            prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
+            decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
+            # make sure `"decoder_attention_mask"` is not passed to forward
+            kwargs.pop("decoder_attention_mask", None)
        else:
            # make sure `"decoder_attention_mask"` is not passed to forward
            kwargs.pop("decoder_attention_mask", None)
File diff suppressed because one or more lines are too long
@@ -1451,6 +1451,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
        # Original model wasn't trained with timestamps and has incorrect generation config
        pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")

+        # the audio is 4 seconds long
        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

        out = pipe(
@@ -1460,11 +1461,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
        self.assertEqual(
            out,
            {
-                "chunks": [
-                    {"text": "", "timestamp": (18.94, 0.02)},
-                    {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
-                ],
                "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
+                "chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
            },
        )

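For context, a hedged sketch of a pipeline call in the spirit of the updated test; the checkpoint, audio path and keyword arguments below are assumptions, not the suppressed test code.

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2", return_timestamps=True)
print(asr("hindi.ogg"))
# roughly: {"text": "...", "chunks": [{"timestamp": (0.58, None), "text": "..."}]}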