From 255257f3ea0862cbb92ea9fa1113cbee1898aadd Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:09:43 +0100 Subject: [PATCH] [Whisper] Refactor whisper (#21252) * update whisper logit processor * add generate for whisper * remove part of the whisper specific code from pipeline * update logit processes * major update * enforce first timestamp * update generate * add more tests * update new decoding strategy * Apply suggestions from code review * update docstring * fixup * default config will not have multilingual ar * update expected tokenizer size, see pull on the hub for whisper-tiny --- src/transformers/generation/logits_process.py | 45 +++--- .../models/whisper/modeling_whisper.py | 145 ++++++++++++++++++ .../models/whisper/tokenization_whisper.py | 26 +++- .../pipelines/automatic_speech_recognition.py | 28 +--- .../whisper/test_tokenization_whisper.py | 12 +- ..._pipelines_automatic_speech_recognition.py | 30 +++- 6 files changed, 231 insertions(+), 55 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 15e35bd21ed..82dff5eab95 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -917,35 +917,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): probs to `inf` so that they are sampled at their corresponding index. Args: - begin_index (`int`, *optional*, defaults to 5 ): - This indicates to the processor where the first tokens are generated. This is used to differentiate between - the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the - `prompt` tokens are the first 4 tokens. - eos_token_id (`int`, *optional*, defaults to 50257): - The id of the *end-of-sequence* token. - no_timestamps_token_id (`int`, *optional*, defaults to 50363): - The id of the `"<|notimestamps|>"` token. - max_initial_timestamp (`int`, *optional*, defaults to 1): - Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting - timestamps that are too far in the future. + generate_config (`GenerateConfig`): + The generate config used to generate the output. The following parameters are required: + eos_token_id (`int`, *optional*, defaults to 50257): + The id of the *end-of-sequence* token. + no_timestamps_token_id (`int`, *optional*, defaults to 50363): + The id of the `"<|notimestamps|>"` token. + max_initial_timestamp_index (`int`, *optional*, defaults to 1): + Used to set the maximum value of the initial timestamp. This is used to prevent the model from + predicting timestamps that are too far in the future. 
""" - def __init__( - self, - begin_index=5, - eos_token_id=50257, - no_timestamps_token_id=50363, - max_initial_timestamp=1, - ): - self.eos_token_id = eos_token_id - self.no_timestamps_token_id = no_timestamps_token_id - self.timestamp_begin = no_timestamps_token_id + 1 - self.begin_index = begin_index - self.max_initial_timestamp_index = max_initial_timestamp + def __init__(self, generate_config): # support for the kwargs + self.eos_token_id = generate_config.eos_token_id + self.no_timestamps_token_id = generate_config.no_timestamps_token_id + self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + + self.begin_index = len(generate_config.forced_decoder_ids) + 1 + if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id: + self.begin_index -= 1 + if generate_config.is_multilingual: + self.begin_index += 1 + self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index def __call__(self, input_ids, scores): # suppress <|notimestamps|> which is handled by without_timestamps scores[:, self.no_timestamps_token_id] = -float("inf") + if input_ids.shape[1] == self.begin_index: + scores[:, self.timestamp_begin] = 0 # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly for k in range(input_ids.shape[0]): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index fb03b63de1f..38bfe133a07 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation.logits_process import WhisperTimeStampLogitsProcessor from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1231,6 +1232,150 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): encoder_attentions=outputs.encoder_attentions, ) + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + **kwargs + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. 
If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://arxiv.org/abs/2010.00904).
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            return_timestamps (`bool`, *optional*):
+                Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
+            task (`str`, *optional*):
+                Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
+                will be updated accordingly.
+            language (`str`, *optional*):
+                Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
+                language tokens in the `model.generation_config.lang_to_id` dictionary.
+            is_multilingual (`bool`, *optional*):
+                Whether or not the model is multilingual.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+ + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + if generation_config is None: + generation_config = self.generation_config + + if return_timestamps is not None: + generation_config.return_timestamps = return_timestamps + + if task is not None: + generation_config.task = task + + if is_multilingual is not None: + generation_config.is_multilingual = is_multilingual + + if language is not None: + generation_config.language = language + + forced_decoder_ids = [] + + if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: + if hasattr(generation_config, "language"): + forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) + else: + forced_decoder_ids.append((1, None)) + + if hasattr(generation_config, "task"): + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) + + if ( + hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps + ) or return_timestamps: + logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)] + else: + if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if len(forced_decoder_ids) > 0: + generation_config.forced_decoder_ids = forced_decoder_ids + + return super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + **kwargs, + ) + def prepare_inputs_for_generation( self, decoder_input_ids, diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 04bacbd312d..102bcab53a1 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -493,6 +493,23 @@ class WhisperTokenizer(PreTrainedTokenizer): normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) return normalizer(text) + def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str: + """ + Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes + given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
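The `generate()` override above wires these options together: `return_timestamps=True` plugs a `WhisperTimeStampLogitsProcessor` built from the generation config into the call, while `task`, `language` and `is_multilingual` are turned into `forced_decoder_ids`. A rough usage sketch, not the exact test setup, assuming a checkpoint whose `generation_config.json` already defines `lang_to_id`, `task_to_id`, `no_timestamps_token_id` and `forced_decoder_ids` (older configs may need these fields added first):

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

# return_timestamps=True enables the timestamp logits processor; task/language/
# is_multilingual are translated into forced_decoder_ids on the generation config.
generated_ids = model.generate(
    inputs.input_features,
    return_timestamps=True,
    task="transcribe",
    language="<|en|>",
    is_multilingual=True,
)
print(processor.tokenizer.decode(generated_ids[0], decode_with_timestamps=True))
```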
+ """ + timestamp_begin = self.all_special_ids[-1] + 1 + outputs = [[]] + for token in token_ids: + if token >= timestamp_begin: + timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>" + outputs.append(timestamp) + outputs.append([]) + else: + outputs[-1].append(token) + outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs] + return "".join(outputs) + def _compute_offsets(self, token_ids, time_precision=0.02): """ Compute offsets for a given tokenized input @@ -544,6 +561,7 @@ class WhisperTokenizer(PreTrainedTokenizer): clean_up_tokenization_spaces: bool = True, output_offsets: bool = False, time_precision=0.02, + decode_with_timestamps: bool = False, **kwargs ) -> str: """ @@ -561,7 +579,11 @@ class WhisperTokenizer(PreTrainedTokenizer): Whether or not to clean up the tokenization spaces. kwargs (additional keyword arguments, *optional*): Will be passed to the underlying model specific decode method. - + output_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output the offsets of the tokens. This should only be set if the model predicted + timestamps. + decode_with_timestamps (`bool`, *optional*, defaults to `False`): + WHether or not to decode with timestamps included in the raw text. Returns: `str`: The decoded sentence. """ @@ -571,6 +593,8 @@ class WhisperTokenizer(PreTrainedTokenizer): clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs, ) + if decode_with_timestamps: + text = self._decode_with_timestamps(token_ids, time_precision=time_precision) # retrieve offsets if output_offsets: offsets = None diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index dd97449ca8a..a409ff21c3d 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -31,8 +31,6 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) if is_torch_available(): - from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor - from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING @@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if return_timestamps is not None: forward_params["return_timestamps"] = return_timestamps postprocess_params["return_timestamps"] = return_timestamps - if self.model.config.model_type == "whisper": - # Whisper is highly specific, if we want timestamps, we need to - # force whisper to output timestamp tokens, which means we need - # to set this variable to prevent `no_timestamp_token` to be - # used in the decoder. 
- if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}): - forward_params["generate_kwargs"]["forced_decoder_ids"] = None return preprocess_params, forward_params, postprocess_params @@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): if generate_kwargs is None: generate_kwargs = {} - + if return_timestamps and self.type == "seq2seq_whisper": + generate_kwargs["return_timestamps"] = return_timestamps is_last = model_inputs.pop("is_last") - if self.type == "seq2seq": + if self.type in {"seq2seq", "seq2seq_whisper"}: encoder = self.model.get_encoder() # Consume values so we can let extra information flow freely through # the pipeline (important for `partial` in microphone) @@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): **generate_kwargs, ) out = {"tokens": tokens} - elif self.type == "seq2seq_whisper": - stride = model_inputs.pop("stride", None) - tokens = self.model.generate( - input_features=model_inputs.pop("input_features"), - logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None, - **generate_kwargs, - ) - out = {"tokens": tokens} - if stride is not None: - out["stride"] = stride + if self.type == "seq2seq_whisper": + stride = model_inputs.pop("stride", None) + if stride is not None: + out["stride"] = stride else: stride = model_inputs.pop("stride", None) diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index c9aebb7f328..7cf6bb627fd 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -59,7 +59,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(len(vocab_keys), 50364) def test_vocab_size(self): - self.assertEqual(self.get_tokenizer().vocab_size, 50257) + self.assertEqual(self.get_tokenizer().vocab_size, 50258) def test_full_tokenizer(self): tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname) @@ -265,7 +265,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): }, ], ) - + # test `decode_with_offsets` + output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True) + self.assertEqual( + output, + "<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom" + " paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut" + " Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash" + " his teeth. And Mr. 
John Colier gives his<|21.70|><|21.70|><|endoftext|>", + ) # test a single sequence with timestamps # fmt: off INPUT_TOKENS = [ diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index a1204eb9f94..59b979cb690 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -28,7 +28,6 @@ from transformers import ( Speech2TextForConditionalGeneration, Wav2Vec2ForCTC, WhisperForConditionalGeneration, - WhisperProcessor, ) from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline from transformers.pipelines.audio_utils import chunk_bytes_iter @@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}], }, ) - pipe = pipeline( - model="openai/whisper-small", - return_timestamps=True, - ) output = pipe(array, chunk_length_s=10) self.assertDictEqual( @@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."}, ) + output = speech_recognizer(filename, return_timestamps=True) + self.assertEqual( + output, + { + "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "chunks": [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ), + "timestamp": (0.0, 5.44), + } + ], + }, + ) @slow @require_torch @@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel output_2 = speech_recognizer_2(filename) self.assertEqual(output, output_2) - processor = WhisperProcessor(feature_extractor, tokenizer) - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it") + # either use generate_kwargs or set the model's generation_config + # model.generation_config.task = "transcribe" + # model.generation_config.lang = "<|it|>" speech_translator = AutomaticSpeechRecognitionPipeline( - model=model, tokenizer=tokenizer, feature_extractor=feature_extractor + model=model, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + generate_kwargs={"task": "transcribe", "language": "<|it|>"}, ) output_3 = speech_translator(filename) self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
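At the pipeline level, the net effect of the refactor is that timestamps and task/language selection no longer require touching `forced_decoder_ids` or instantiating the logits processor by hand; both are routed through `model.generate()`, mirroring the updated tests above. A usage sketch (the audio path is a placeholder, and `<|it|>` assumes the checkpoint's generation config maps that language token):

```python
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    # Forwarded to model.generate(); assumes task_to_id / lang_to_id are present
    # in the checkpoint's generation config.
    generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)

# "sample.flac" is a placeholder path to any audio file the pipeline can read.
out = asr("sample.flac", return_timestamps=True)
print(out["text"])
print(out["chunks"])  # [{"text": ..., "timestamp": (start, end)}, ...]
```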