mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Whisper] Refactor whisper (#21252)
* update whisper logit processor
* add generate for whisper
* remove part of the whisper specific code from pipeline
* update logit processes
* major update
* enforce first timestamp
* update generate
* add more tests
* update new decoding strategy
* Apply suggestions from code review
* update docstring
* fixup
* default config will not have multilingual ar
* update expected tokenizer size, see pull on the hub for whisper-tiny
This commit is contained in: parent f83135eb76, commit 255257f3ea
@@ -917,35 +917,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
probs to `inf` so that they are sampled at their corresponding index.

Args:
    begin_index (`int`, *optional*, defaults to 5):
        This indicates to the processor where the first tokens are generated. This is used to differentiate between
        the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
        `prompt` tokens are the first 4 tokens.
    eos_token_id (`int`, *optional*, defaults to 50257):
        The id of the *end-of-sequence* token.
    no_timestamps_token_id (`int`, *optional*, defaults to 50363):
        The id of the `"<|notimestamps|>"` token.
    max_initial_timestamp (`int`, *optional*, defaults to 1):
        Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
        timestamps that are too far in the future.
    generate_config (`GenerateConfig`):
        The generate config used to generate the output. The following parameters are required:
            eos_token_id (`int`, *optional*, defaults to 50257):
                The id of the *end-of-sequence* token.
            no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                The id of the `"<|notimestamps|>"` token.
            max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                predicting timestamps that are too far in the future.
"""

def __init__(
    self,
    begin_index=5,
    eos_token_id=50257,
    no_timestamps_token_id=50363,
    max_initial_timestamp=1,
):
    self.eos_token_id = eos_token_id
    self.no_timestamps_token_id = no_timestamps_token_id
    self.timestamp_begin = no_timestamps_token_id + 1
    self.begin_index = begin_index
    self.max_initial_timestamp_index = max_initial_timestamp
def __init__(self, generate_config):  # support for the kwargs
    self.eos_token_id = generate_config.eos_token_id
    self.no_timestamps_token_id = generate_config.no_timestamps_token_id
    self.timestamp_begin = generate_config.no_timestamps_token_id + 1

    self.begin_index = len(generate_config.forced_decoder_ids) + 1
    if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
        self.begin_index -= 1
    if generate_config.is_multilingual:
        self.begin_index += 1
    self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

def __call__(self, input_ids, scores):
    # suppress <|notimestamps|> which is handled by without_timestamps
    scores[:, self.no_timestamps_token_id] = -float("inf")
    if input_ids.shape[1] == self.begin_index:
        scores[:, self.timestamp_begin] = 0

    # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
    for k in range(input_ids.shape[0]):
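A minimal sketch (not part of the diff) of the attributes the refactored processor now reads from its `generate_config`; `SimpleNamespace` stands in for a real `GenerationConfig`, and the token ids are illustrative.

    from types import SimpleNamespace
    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

    fake_config = SimpleNamespace(
        eos_token_id=50257,
        no_timestamps_token_id=50363,
        forced_decoder_ids=[(1, 50259), (2, 50359)],  # <|lang|> and <|task|> prompt ids (illustrative)
        is_multilingual=True,
        max_initial_timestamp_index=1,
    )
    processor = WhisperTimeStampLogitsProcessor(fake_config)
    # begin_index = len(forced_decoder_ids) + 1, plus 1 because is_multilingual -> 4,
    # so the timestamp constraints only kick in once the prompt tokens are consumed.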
@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
@@ -1231,6 +1232,150 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
    encoder_attentions=outputs.encoder_attentions,
)

def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config=None,
    logits_processor=None,
    stopping_criteria=None,
    prefix_allowed_tokens_fn=None,
    synced_gpus=False,
    return_timestamps=None,
    task=None,
    language=None,
    is_multilingual=None,
    **kwargs
):
    """

    Generates sequences of token ids for models with a language modeling head.

    <Tip warning={true}>

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

    For an overview of generation strategies and code examples, check out the [following
    guide](./generation_strategies).

    </Tip>

    Parameters:
        inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
            The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
            method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
            should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
            `input_ids`, `input_values`, `input_features`, or `pixel_values`.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which has the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors that complement the default logits processors built from arguments and
            generation config. If a logit processor is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            Custom stopping criteria that complement the default stopping criteria built from arguments and a
            generation config. If a stopping criteria is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
            If provided, this function constrains the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
            `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
            on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
            for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
            Retrieval](https://arxiv.org/abs/2010.00904).
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
        return_timestamps (`bool`, *optional*):
            Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
        task (`str`, *optional*):
            Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
            will be updated accordingly.
        language (`str`, *optional*):
            Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
            language tokens in the `model.generation_config.lang_to_id` dictionary.
        is_multilingual (`bool`, *optional*):
            Whether or not the model is multilingual.
        kwargs:
            Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

    Return:
        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
        or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.

        If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
        [`~utils.ModelOutput`] types are:

            - [`~generation.GreedySearchDecoderOnlyOutput`],
            - [`~generation.SampleDecoderOnlyOutput`],
            - [`~generation.BeamSearchDecoderOnlyOutput`],
            - [`~generation.BeamSampleDecoderOnlyOutput`]

        If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
        [`~utils.ModelOutput`] types are:

            - [`~generation.GreedySearchEncoderDecoderOutput`],
            - [`~generation.SampleEncoderDecoderOutput`],
            - [`~generation.BeamSearchEncoderDecoderOutput`],
            - [`~generation.BeamSampleEncoderDecoderOutput`]
    """
    if generation_config is None:
        generation_config = self.generation_config

    if return_timestamps is not None:
        generation_config.return_timestamps = return_timestamps

    if task is not None:
        generation_config.task = task

    if is_multilingual is not None:
        generation_config.is_multilingual = is_multilingual

    if language is not None:
        generation_config.language = language

    forced_decoder_ids = []

    if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
        if hasattr(generation_config, "language"):
            forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
        else:
            forced_decoder_ids.append((1, None))

        if hasattr(generation_config, "task"):
            forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
        else:
            forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

    if (
        hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
    ) or return_timestamps:
        logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
    else:
        if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
            idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
            forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

    if len(forced_decoder_ids) > 0:
        generation_config.forced_decoder_ids = forced_decoder_ids

    return super().generate(
        inputs,
        generation_config,
        logits_processor,
        stopping_criteria,
        prefix_allowed_tokens_fn,
        synced_gpus,
        **kwargs,
    )
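A hedged usage sketch of the generate() signature added above (not part of the diff). It assumes the `openai/whisper-tiny` checkpoint and that its generation config on the Hub carries the `lang_to_id`/`task_to_id` maps this code looks up; the audio is a silent placeholder.

    import numpy as np
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # one second of silence stands in for real 16 kHz audio
    audio = np.zeros(16_000, dtype=np.float32)
    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

    # task/language/return_timestamps are folded into generation_config; with
    # return_timestamps=True the WhisperTimeStampLogitsProcessor is attached internally
    predicted_ids = model.generate(
        inputs.input_features,
        task="transcribe",
        language="<|en|>",
        return_timestamps=True,
    )
    text = processor.batch_decode(predicted_ids, decode_with_timestamps=True)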
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
@@ -493,6 +493,23 @@ class WhisperTokenizer(PreTrainedTokenizer):
    normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
    return normalizer(text)

def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
    """
    Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
    given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
    """
    timestamp_begin = self.all_special_ids[-1] + 1
    outputs = [[]]
    for token in token_ids:
        if token >= timestamp_begin:
            timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
            outputs.append(timestamp)
            outputs.append([])
        else:
            outputs[-1].append(token)
    outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
    return "".join(outputs)

def _compute_offsets(self, token_ids, time_precision=0.02):
    """
    Compute offsets for a given tokenized input
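Illustrative arithmetic only (not part of the diff): under the default multilingual layout where `<|notimestamps|>` is 50363, `timestamp_begin` is 50364 and each timestamp step is `time_precision = 0.02` seconds.

    time_precision = 0.02
    timestamp_begin = 50364
    for token in (50364, 50414, 50464):
        print(f"<|{(token - timestamp_begin) * time_precision:.2f}|>")
    # -> <|0.00|>, <|1.00|>, <|2.00|>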
@@ -544,6 +561,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
    clean_up_tokenization_spaces: bool = True,
    output_offsets: bool = False,
    time_precision=0.02,
    decode_with_timestamps: bool = False,
    **kwargs
) -> str:
    """
@@ -561,7 +579,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
        Whether or not to clean up the tokenization spaces.
    kwargs (additional keyword arguments, *optional*):
        Will be passed to the underlying model specific decode method.

    output_offsets (`bool`, *optional*, defaults to `False`):
        Whether or not to output the offsets of the tokens. This should only be set if the model predicted
        timestamps.
    decode_with_timestamps (`bool`, *optional*, defaults to `False`):
        Whether or not to decode with timestamps included in the raw text.
Returns:
    `str`: The decoded sentence.
"""
@@ -571,6 +593,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
    **kwargs,
)
if decode_with_timestamps:
    text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
# retrieve offsets
if output_offsets:
    offsets = None
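A short sketch of the new decode flags (not part of the diff); the token ids below are illustrative stand-ins and assume the default layout where timestamp tokens start at 50364. `output_offsets=True` additionally returns the offsets built by `_compute_offsets`.

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    # 50364 and 50514 are the timestamp tokens <|0.00|> and <|3.00|>; the ids in
    # between stand in for ordinary text tokens.
    token_ids = [50364, 634, 575, 12525, 22618, 1968, 6144, 50514]

    print(tokenizer.decode(token_ids))                               # timestamp tokens dropped (per the docstring above)
    print(tokenizer.decode(token_ids, decode_with_timestamps=True))  # "<|0.00|>...<|3.00|>"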
@@ -31,8 +31,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)

if is_torch_available():
    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps is not None:
    forward_params["return_timestamps"] = return_timestamps
    postprocess_params["return_timestamps"] = return_timestamps
    if self.model.config.model_type == "whisper":
        # Whisper is highly specific, if we want timestamps, we need to
        # force whisper to output timestamp tokens, which means we need
        # to set this variable to prevent `no_timestamp_token` to be
        # used in the decoder.
        if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
            forward_params["generate_kwargs"]["forced_decoder_ids"] = None

return preprocess_params, forward_params, postprocess_params
@@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
    if generate_kwargs is None:
        generate_kwargs = {}

    if return_timestamps and self.type == "seq2seq_whisper":
        generate_kwargs["return_timestamps"] = return_timestamps
    is_last = model_inputs.pop("is_last")

    if self.type == "seq2seq":
    if self.type in {"seq2seq", "seq2seq_whisper"}:
        encoder = self.model.get_encoder()
        # Consume values so we can let extra information flow freely through
        # the pipeline (important for `partial` in microphone)
@@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        **generate_kwargs,
    )
    out = {"tokens": tokens}
elif self.type == "seq2seq_whisper":
    stride = model_inputs.pop("stride", None)
    tokens = self.model.generate(
        input_features=model_inputs.pop("input_features"),
        logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
        **generate_kwargs,
    )
    out = {"tokens": tokens}
    if stride is not None:
        out["stride"] = stride
    if self.type == "seq2seq_whisper":
        stride = model_inputs.pop("stride", None)
        if stride is not None:
            out["stride"] = stride

else:
    stride = model_inputs.pop("stride", None)
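A usage sketch of the pipeline path above (not part of the diff): with `return_timestamps=True` the whisper branch forwards the flag through `generate_kwargs`, so the model emits timestamp tokens and postprocessing can build chunks. The checkpoint and file name are illustrative.

    from transformers import pipeline

    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny",
        return_timestamps=True,
    )
    out = asr("sample.flac", chunk_length_s=10)
    # out -> {"text": "...", "chunks": [{"text": "...", "timestamp": (0.0, 5.44)}, ...]}  (illustrative shape)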
@@ -59,7 +59,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    self.assertEqual(len(vocab_keys), 50364)

def test_vocab_size(self):
    self.assertEqual(self.get_tokenizer().vocab_size, 50257)
    self.assertEqual(self.get_tokenizer().vocab_size, 50258)

def test_full_tokenizer(self):
    tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
@@ -265,7 +265,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
        },
    ],
)

# test `decode_with_offsets`
output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True)
self.assertEqual(
    output,
    "<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom"
    " paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut"
    " Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash"
    " his teeth. And Mr. John Colier gives his<|21.70|><|21.70|><|endoftext|>",
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [
@@ -28,7 +28,6 @@ from transformers import (
    Speech2TextForConditionalGeneration,
    Wav2Vec2ForCTC,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
        "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
    },
)
pipe = pipeline(
    model="openai/whisper-small",
    return_timestamps=True,
)

output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
@@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
    output,
    {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
self.assertEqual(
    output,
    {
        "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
        "chunks": [
            {
                "text": (
                    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
                ),
                "timestamp": (0.0, 5.44),
            }
        ],
    },
)

@slow
@require_torch
@@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)

processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
# either use generate_kwargs or set the model's generation_config
# model.generation_config.task = "transcribe"
# model.generation_config.lang = "<|it|>"
speech_translator = AutomaticSpeechRecognitionPipeline(
    model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
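The test above exercises the two equivalent ways of pinning language and task after this refactor; a condensed sketch (checkpoint and audio file names are illustrative):

    from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

    # 1) pre-refactor route: force the prompt ids on the model config
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")

    # 2) new route: let generate() build forced_decoder_ids from generate_kwargs
    asr = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs={"task": "transcribe", "language": "<|it|>"},
    )
    out = asr("sample_it.flac")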