feat: Whisper prompting (#22496)

* initial working additions

* clean and rename, add cond stripping initial prompt to decode

* cleanup, edit create_initial_prompt_ids, add tests

* repo consistency, flip order of conditional

* fix error, move the processor fn to the tokenizer

* repo consistency, update test ids to corresponding tokenizer

* use convert_tokens_to_ids not get_vocab...

* use actual conditional in generate

* make sytle

* initial address comments

* initial working add new params to pipeline

* first draft of sequential generation for condition_on_previous_text

* add/update tests, make compatible with timestamps

* make compatible with diff. input kwargs and max length

* add None check

* add temperature check

* flip temp check operand

* refocusing to prev pr scope

* remove the params too

* make style

* edits, move max length incorporating prompt to whisper

* address comments

* remove asr pipeline prompt decoding, fix indexing

* address comments (more tests, validate prompt)

* un-comment out tests (from debug)

* remove old comment

* address comments

* fix typo

* remove timestamp token from test

* make style

* cleanup

* copy method to fast tokenizer, set max_new_tokens for test

* prompt_ids type just pt

* address Amy's comments

* make style
This commit is contained in:
Connor Henderson 2023-05-19 04:33:11 -04:00 committed by GitHub
parent a7920065f2
commit 2acedf4721
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 272 additions and 15 deletions

View File

@ -34,7 +34,12 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
@ -1464,6 +1469,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
task=None,
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
**kwargs,
):
"""
@ -1521,6 +1527,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
prompt_ids (`torch.Tensor`, *optional*):
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -1567,8 +1578,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
if task is not None:
generation_config.task = task
forced_decoder_ids = []
if task is not None or language is not None:
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
@ -1593,27 +1617,48 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
else:
elif hasattr(generation_config, "task_to_id"):
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids
if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
# Update the max generation length to include the prompt
specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
default_max_length = generation_config.max_new_tokens or generation_config.max_length
non_prompt_max_length = specified_max_length or default_max_length
kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
)
forced_decoder_ids = [
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
*text_prompt_ids[-self.config.max_length // 2 - 1 :],
generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if len(forced_decoder_ids) > 0:
generation_config.forced_decoder_ids = forced_decoder_ids
return super().generate(
inputs,
generation_config,

View File

@ -16,6 +16,7 @@
Speech processor class for Whisper
"""
from ...processing_utils import ProcessorMixin
@ -91,3 +92,6 @@ class WhisperProcessor(ProcessorMixin):
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def get_prompt_ids(self, text: str, return_tensors="np"):
return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)

View File

@ -606,6 +606,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# To avoid mixing byte-level and unicode for byte-level BPT
@ -714,6 +719,31 @@ class WhisperTokenizer(PreTrainedTokenizer):
time_precision=time_precision,
)
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
if special_token_id is not None:
token = self.convert_ids_to_tokens(special_token_id)
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]
@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []
return token_ids
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""

View File

@ -312,6 +312,11 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return text
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
if kwargs["skip_special_tokens"]:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
text = super()._decode(*args, **kwargs)
if normalize:
@ -485,3 +490,30 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return_language=return_language,
time_precision=time_precision,
)
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
if special_token_id is not None:
token = self.convert_ids_to_tokens(special_token_id)
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]
@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []
return token_ids

View File

@ -1013,6 +1013,48 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
def test_generate_with_prompt_ids_and_task_and_language(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"]
prompt_ids = np.arange(5)
language = "<|de|>"
task = "translate"
lang_id = 6
task_id = 7
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
model.generation_config.__setattr__("task_to_id", {task: task_id})
output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
lang_id,
task_id,
]
for row in output.tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
input_features = input_dict["input_features"]
prompt_ids = np.asarray(range(5))
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
output = model.generate(
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
)
expected_output_start = [
*prompt_ids.tolist(),
model.generation_config.decoder_start_token_id,
*[token for _rank, token in forced_decoder_ids],
]
for row in output.tolist():
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
@require_torch
@require_torchaudio
@ -1429,6 +1471,60 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@slow
def test_generate_with_prompt_ids(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(4)[-1:]
input_features = processor(input_speech, return_tensors="pt").input_features
output_without_prompt = model.generate(input_features)
prompt_ids = processor.get_prompt_ids("Leighton")
output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)
expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)
@slow
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features
task = "translate"
language = "de"
expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
prompt = "test prompt"
prompt_ids = processor.get_prompt_ids(prompt)
output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
text = processor.decode(output[0])
self.assertTrue(prompt in text)
self.assertTrue(all([token in text for token in expected_tokens]))
@slow
def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features
prompt = "test prompt"
prompt_ids = processor.get_prompt_ids(prompt)
model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None
output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
text = processor.decode(output[0])
self.assertTrue(prompt in text)
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:

View File

@ -16,6 +16,8 @@ import shutil
import tempfile
import unittest
import pytest
from transformers import WhisperTokenizer, is_speech_available
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
@ -146,3 +148,32 @@ class WhisperProcessorTest(unittest.TestCase):
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
def test_get_prompt_ids(self):
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
prompt_ids = processor.get_prompt_ids("Mr. Quilter")
decoded_prompt = processor.tokenizer.decode(prompt_ids)
self.assertListEqual(prompt_ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
self.assertEqual(decoded_prompt, "<|startofprev|> Mr. Quilter")
def test_empty_get_prompt_ids(self):
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
prompt_ids = processor.get_prompt_ids("")
decoded_prompt = processor.tokenizer.decode(prompt_ids)
self.assertListEqual(prompt_ids.tolist(), [50360, 220])
self.assertEqual(decoded_prompt, "<|startofprev|> ")
def test_get_prompt_ids_with_special_tokens(self):
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
def _test_prompt_error_raised_helper(prompt, special_token):
with pytest.raises(ValueError) as excinfo:
processor.get_prompt_ids(prompt)
expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
self.assertEqual(expected, str(excinfo.value))
_test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
_test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
_test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")

View File

@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
merge = _find_longest_common_sequence([seq1, seq2, seq3])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
def test_skip_special_tokens_skips_prompt_ids(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# fmt: off
encoded_input = [
50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
]
# fmt: on
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
)
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"