[Whisper] Refactor whisper (#21252)

* update whisper logit processor

* add generate for whisper

* remove part of the whisper specific code from pipeline

* update logit processes

* major update

* enforce first timestamp

* update generate

* add more tests

* update new decoding strategy

* Apply suggestions from code review

* update docstring

* fixup

* default config will not have multilingual ar

* update expected tokenizer size, see pull on the hub for whisper-tiny
Arthur 2023-01-25 13:09:43 +01:00 committed by GitHub
parent f83135eb76
commit 255257f3ea
6 changed files with 231 additions and 55 deletions


@@ -917,35 +917,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
probs to `inf` so that they are sampled at their corresponding index.
Args:
begin_index (`int`, *optional*, defaults to 5):
This indicates to the processor where the first tokens are generated. This is used to differentiate between
the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
`prompt` tokens are the first 4 tokens.
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
timestamps that are too far in the future.
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""
def __init__(
self,
begin_index=5,
eos_token_id=50257,
no_timestamps_token_id=50363,
max_initial_timestamp=1,
):
self.eos_token_id = eos_token_id
self.no_timestamps_token_id = no_timestamps_token_id
self.timestamp_begin = no_timestamps_token_id + 1
self.begin_index = begin_index
self.max_initial_timestamp_index = max_initial_timestamp
def __init__(self, generate_config): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = len(generate_config.forced_decoder_ids) + 1
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
if generate_config.is_multilingual:
self.begin_index += 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
def __call__(self, input_ids, scores):
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
if input_ids.shape[1] == self.begin_index:
scores[:, self.timestamp_begin] = 0
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):

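The refactor moves all of the processor's configuration onto the generation config. Below is a minimal, editor-added sketch (not part of the diff) of how `begin_index` is derived; the token ids are assumptions matching the multilingual vocabulary, not values read from a checkpoint:

```python
# Illustrative sketch only: a stand-in generation config with assumed token ids
# (<|en|> = 50259, <|transcribe|> = 50359, <|notimestamps|> = 50363 in the
# multilingual vocabulary); real values come from the model's generation config.
from types import SimpleNamespace

from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

stub_config = SimpleNamespace(
    eos_token_id=50257,
    no_timestamps_token_id=50363,
    forced_decoder_ids=[(1, 50259), (2, 50359)],  # (position, token_id) pairs
    is_multilingual=True,
    max_initial_timestamp_index=1,
)

processor = WhisperTimeStampLogitsProcessor(stub_config)
# begin_index = len(forced_decoder_ids) + 1 = 3, then +1 because is_multilingual
print(processor.begin_index)  # 4
```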

@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -1231,6 +1232,150 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
encoder_attentions=outputs.encoder_attentions,
)
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
prefix_allowed_tokens_fn=None,
synced_gpus=False,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
**kwargs
):
"""
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format of `input_ids`. For encoder-decoder models `inputs` can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until `max_length` (needed for ZeRO stage 3).
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`str`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
if generation_config is None:
generation_config = self.generation_config
if return_timestamps is not None:
generation_config.return_timestamps = return_timestamps
if task is not None:
generation_config.task = task
if is_multilingual is not None:
generation_config.is_multilingual = is_multilingual
if language is not None:
generation_config.language = language
forced_decoder_ids = []
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
if hasattr(generation_config, "language"):
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
else:
forced_decoder_ids.append((1, None))
if hasattr(generation_config, "task"):
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
if (
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
) or return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
else:
if forced_decoder_ids and forced_decoder_ids[-1][1] != generation_config.no_timestamps_token_id:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,

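Taken together, the new arguments replace manual `forced_decoder_ids` handling when calling `generate`. A rough usage sketch added for illustration (the checkpoint, the silent stand-in audio, and the presence of `lang_to_id`/`task_to_id`/`no_timestamps_token_id` in the checkpoint's generation config are assumptions):

```python
# Rough usage sketch, not an excerpt from the PR. Assumes "openai/whisper-tiny" ships a
# generation config with `lang_to_id`, `task_to_id` and `no_timestamps_token_id`.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 30 seconds of silence as a stand-in for real audio sampled at 16 kHz
input_features = processor(
    torch.zeros(480_000).numpy(), sampling_rate=16_000, return_tensors="pt"
).input_features

# `generate` now builds `forced_decoder_ids` from `task`/`language` and, when
# `return_timestamps=True`, adds the `WhisperTimeStampLogitsProcessor` itself.
predicted_ids = model.generate(
    input_features,
    task="transcribe",
    language="<|en|>",
    return_timestamps=True,
)
print(processor.tokenizer.decode(predicted_ids[0], decode_with_timestamps=True))
```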

@@ -493,6 +493,23 @@ class WhisperTokenizer(PreTrainedTokenizer):
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text)
def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
"""
timestamp_begin = self.all_special_ids[-1] + 1
outputs = [[]]
for token in token_ids:
if token >= timestamp_begin:
timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
return "".join(outputs)
def _compute_offsets(self, token_ids, time_precision=0.02):
"""
Compute offsets for a given tokenized input
@@ -544,6 +561,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
clean_up_tokenization_spaces: bool = True,
output_offsets: bool = False,
time_precision=0.02,
decode_with_timestamps: bool = False,
**kwargs
) -> str:
"""
@@ -561,7 +579,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
Whether or not to clean up the tokenization spaces.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""
@@ -571,6 +593,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
if decode_with_timestamps:
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
# retrieve offsets
if output_offsets:
offsets = None

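A short editor-added sketch of the new tokenizer flag; the checkpoint and the `<|0.00|>`/`<|2.00|>` ids (50364 and 50464 in the multilingual vocabulary) are assumptions:

```python
# Sketch only: timestamp token ids are assumed to start right after <|notimestamps|>
# (id 50363), so 50364 ~ <|0.00|> and 50464 ~ <|2.00|> with the default 0.02 precision.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

text_ids = tokenizer.encode(" Hello there", add_special_tokens=False)
token_ids = [50364] + text_ids + [50464]

# Plain decode() ignores the timestamp tokens; the new flag annotates them instead.
print(tokenizer.decode(token_ids, decode_with_timestamps=True))
# -> "<|0.00|> Hello there<|2.00|>"
```

Passing `output_offsets=True` instead returns the same information as `(start, end)` tuples per segment rather than inline annotations.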

@@ -31,8 +31,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if self.model.config.model_type == "whisper":
# Whisper is highly specific, if we want timestamps, we need to
# force whisper to output timestamp tokens, which means we need
# to set this variable to prevent `no_timestamp_token` to be
# used in the decoder.
if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
forward_params["generate_kwargs"]["forced_decoder_ids"] = None
return preprocess_params, forward_params, postprocess_params
@@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
is_last = model_inputs.pop("is_last")
if self.type == "seq2seq":
if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
@@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
**generate_kwargs,
)
out = {"tokens": tokens}
elif self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
tokens = self.model.generate(
input_features=model_inputs.pop("input_features"),
logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
**generate_kwargs,
)
out = {"tokens": tokens}
if stride is not None:
out["stride"] = stride
if self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
if stride is not None:
out["stride"] = stride
else:
stride = model_inputs.pop("stride", None)

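With the whisper-specific branch gone, timestamps flow through the regular `return_timestamps`/`generate_kwargs` path. A sketch mirroring the pipeline test added further below (the checkpoint and audio path are placeholders):

```python
# Editor's sketch; "audio.flac" is a placeholder path and "openai/whisper-tiny"
# an assumed checkpoint.
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

result = asr("audio.flac", return_timestamps=True)
print(result["text"])
print(result["chunks"])  # e.g. [{"text": "...", "timestamp": (0.0, 5.44)}]
```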

@@ -59,7 +59,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(len(vocab_keys), 50364)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 50257)
self.assertEqual(self.get_tokenizer().vocab_size, 50258)
def test_full_tokenizer(self):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
@@ -265,7 +265,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
},
],
)
# test `decode_with_timestamps`
output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True)
self.assertEqual(
output,
"<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom"
" paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut"
" Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash"
" his teeth. And Mr. John Colier gives his<|21.70|><|21.70|><|endoftext|>",
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [


@@ -28,7 +28,6 @@ from transformers import (
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
},
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
@@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 5.44),
}
],
},
)
@slow
@require_torch
@@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
# either use generate_kwargs or set the model's generation_config
# model.generation_config.task = "transcribe"
# model.generation_config.language = "<|it|>"
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})