feat: Whisper prompting (#22496)
* initial working additions
* clean and rename, add cond stripping initial prompt to decode
* cleanup, edit create_initial_prompt_ids, add tests
* repo consistency, flip order of conditional
* fix error, move the processor fn to the tokenizer
* repo consistency, update test ids to corresponding tokenizer
* use convert_tokens_to_ids not get_vocab...
* use actual conditional in generate
* make sytle
* initial address comments
* initial working add new params to pipeline
* first draft of sequential generation for condition_on_previous_text
* add/update tests, make compatible with timestamps
* make compatible with diff. input kwargs and max length
* add None check
* add temperature check
* flip temp check operand
* refocusing to prev pr scope
* remove the params too
* make style
* edits, move max length incorporating prompt to whisper
* address comments
* remove asr pipeline prompt decoding, fix indexing
* address comments (more tests, validate prompt)
* un-comment out tests (from debug)
* remove old comment
* address comments
* fix typo
* remove timestamp token from test
* make style
* cleanup
* copy method to fast tokenizer, set max_new_tokens for test
* prompt_ids type just pt
* address Amy's comments
* make style
Parent commit: a7920065f2
This commit: 2acedf4721
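For orientation, the feature added by this commit can be exercised end to end roughly as follows. This is an illustrative sketch only; the checkpoint, dataset, and prompt text are arbitrary choices and are not taken from the diff itself.

# End-to-end sketch of Whisper prompting as introduced in this commit.
# Checkpoint, dataset, and prompt text are illustrative assumptions.
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[3]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

# New: tokenize a text prompt (it gets prefixed with <|startofprev|>) ...
prompt_ids = processor.get_prompt_ids("Leighton", return_tensors="pt")

# ... and condition generation on it.
output = model.generate(input_features, prompt_ids=prompt_ids)

# Also new: skip_special_tokens=True strips the prompt segment from the decoded text.
print(processor.decode(output[0], skip_special_tokens=True))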
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_whisper import WhisperConfig
 from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
 
@@ -1464,6 +1469,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         task=None,
         language=None,
         is_multilingual=None,
+        prompt_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
@@ -1521,6 +1527,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
             is_multilingual (`bool`, *optional*):
                 Whether or not the model is multilingual.
+            prompt_ids (`torch.Tensor`, *optional*):
+                Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
+                provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
+                transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
+                correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
             kwargs:
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1567,8 +1578,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         if task is not None:
             generation_config.task = task
 
-        forced_decoder_ids = []
-        if task is not None or language is not None:
+        forced_decoder_ids = None
+
+        # Legacy code for backward compatibility
+        if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
+            forced_decoder_ids = self.config.forced_decoder_ids
+        elif (
+            hasattr(self.generation_config, "forced_decoder_ids")
+            and self.generation_config.forced_decoder_ids is not None
+        ):
+            forced_decoder_ids = self.generation_config.forced_decoder_ids
+        else:
+            forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
+
+        if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
+            forced_decoder_ids = []
             if hasattr(generation_config, "language"):
                 if generation_config.language in generation_config.lang_to_id.keys():
                     language_token = generation_config.language
@@ -1593,27 +1617,48 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                     raise ValueError(
                         f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
                     )
-            else:
+            elif hasattr(generation_config, "task_to_id"):
                 forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
             if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
                 idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                 forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
 
-        # Legacy code for backward compatibility
-        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
-            forced_decoder_ids = self.config.forced_decoder_ids
-        elif (
-            hasattr(self.generation_config, "forced_decoder_ids")
-            and self.generation_config.forced_decoder_ids is not None
-        ):
-            forced_decoder_ids = self.generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
+            generation_config.forced_decoder_ids = forced_decoder_ids
 
+        if prompt_ids is not None:
+            if kwargs.get("decoder_start_token_id") is not None:
+                raise ValueError(
+                    "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
+                )
+            prompt_ids = prompt_ids.tolist()
+            decoder_start_token_id, *text_prompt_ids = prompt_ids
+            # Set the decoder_start_token_id to <|startofprev|>
+            kwargs.update({"decoder_start_token_id": decoder_start_token_id})
+
+            # Update the max generation length to include the prompt
+            specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
+            default_max_length = generation_config.max_new_tokens or generation_config.max_length
+            non_prompt_max_length = specified_max_length or default_max_length
+            kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
+
+            # Reformat the forced_decoder_ids to incorporate the prompt
+            non_prompt_forced_decoder_ids = (
+                kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
+            )
+            forced_decoder_ids = [
+                # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
+                # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
+                *text_prompt_ids[-self.config.max_length // 2 - 1 :],
+                generation_config.decoder_start_token_id,
+                *[token for _rank, token in non_prompt_forced_decoder_ids],
+            ]
+            forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
+            generation_config.forced_decoder_ids = forced_decoder_ids
+
         if generation_config.return_timestamps:
             logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
 
-        if len(forced_decoder_ids) > 0:
-            generation_config.forced_decoder_ids = forced_decoder_ids
-
         return super().generate(
             inputs,
             generation_config,
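To make the prompt handling in generate() above concrete, here is a small standalone sketch of how prompt_ids are folded into forced_decoder_ids and max_new_tokens. All token IDs, lengths, and names below are illustrative placeholders, not values from the diff.

# Standalone illustration of the forced_decoder_ids rewrite generate() performs
# when prompt_ids is passed. IDs and lengths are placeholders.
prompt_ids = [50361, 1770, 13, 2264, 346, 353]  # <|startofprev|> followed by prompt text tokens (placeholder IDs)
non_prompt_forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50363)]  # placeholder language/task/no-timestamps entries
decoder_start_token_id = 50258  # placeholder for <|startoftranscript|>
max_length = 448  # assumed model max_length (typical for Whisper checkpoints)
non_prompt_max_length = 5  # what the caller asked for via max_new_tokens

# The first prompt token (<|startofprev|>) becomes the decoder start token for this call.
startofprev_id, *text_prompt_ids = prompt_ids

# The prompt occupies part of the generated sequence, so max_new_tokens is extended.
max_new_tokens = non_prompt_max_length + len(text_prompt_ids)

# Prompt text (capped to roughly half the context window, as in the OpenAI reference),
# then <|startoftranscript|>, then the usual forced tokens, re-ranked from position 1.
forced_decoder_ids = [
    *text_prompt_ids[-max_length // 2 - 1 :],
    decoder_start_token_id,
    *[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]

print(max_new_tokens)      # 10
print(forced_decoder_ids)  # [(1, 1770), (2, 13), ..., (6, 50258), (7, 50259), (8, 50359), (9, 50363)]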
@@ -16,6 +16,7 @@
 Speech processor class for Whisper
 """
 
+
 from ...processing_utils import ProcessorMixin
 
 
@@ -91,3 +92,6 @@ class WhisperProcessor(ProcessorMixin):
         the docstring of this method for more information.
         """
         return self.tokenizer.decode(*args, **kwargs)
+
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
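The processor method simply forwards to the tokenizer. A quick round-trip sketch follows; the checkpoint is an illustrative assumption, and the expected values are the ones asserted in test_get_prompt_ids further down in this diff.

# Round-trip of the new processor helper; checkpoint is an assumed English-only one.
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
prompt_ids = processor.get_prompt_ids("Mr. Quilter")  # default return_tensors="np"
print(prompt_ids.tolist())                     # the test below expects [50360, 1770, 13, 2264, 346, 353]
print(processor.tokenizer.decode(prompt_ids))  # the test below expects "<|startofprev|> Mr. Quilter"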
@@ -606,6 +606,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
 
+        if skip_special_tokens:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
 
         # To avoid mixing byte-level and unicode for byte-level BPT
@@ -714,6 +719,31 @@ class WhisperTokenizer(PreTrainedTokenizer):
             time_precision=time_precision,
         )
 
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
+        batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
+
+        # Check for special tokens
+        prompt_text_ids = batch_encoding["input_ids"][1:]
+        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
+        if special_token_id is not None:
+            token = self.convert_ids_to_tokens(special_token_id)
+            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
+
+        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
+        return batch_encoding["input_ids"]
+
+    @staticmethod
+    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+        if has_prompt:
+            if decoder_start_token_id in token_ids:
+                return token_ids[token_ids.index(decoder_start_token_id) :]
+            else:
+                return []
+
+        return token_ids
+
 
 def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
     """
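As an illustration of the decode-side change above: with skip_special_tokens=True the prompt segment (everything before <|startoftranscript|>) is now stripped before decoding. The checkpoint below is an assumed multilingual one; the IDs are the ones used in test_skip_special_tokens_skips_prompt_ids at the end of this diff.

# Decode-side effect of the _strip_prompt hook; checkpoint is an assumption.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
encoded_input = [
    50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
    50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
    2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
]
# Keeps the prompt: "<|startofprev|> Mr. Quilter<|startoftranscript|>..." per the test below.
print(tokenizer.decode(encoded_input, skip_special_tokens=False))
# Drops the prompt: " On the general principles of art, Mr. Quilter writes with equal lucidity."
print(tokenizer.decode(encoded_input, skip_special_tokens=True))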
@@ -312,6 +312,11 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return text
 
     def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
+        if kwargs["skip_special_tokens"]:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
+
         text = super()._decode(*args, **kwargs)
 
         if normalize:
@@ -485,3 +490,30 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             return_language=return_language,
             time_precision=time_precision,
         )
+
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        """Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
+        batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
+
+        # Check for special tokens
+        prompt_text_ids = batch_encoding["input_ids"][1:]
+        special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
+        if special_token_id is not None:
+            token = self.convert_ids_to_tokens(special_token_id)
+            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
+
+        batch_encoding.convert_to_tensors(tensor_type=return_tensors)
+        return batch_encoding["input_ids"]
+
+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
+    def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
+        has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
+        if has_prompt:
+            if decoder_start_token_id in token_ids:
+                return token_ids[token_ids.index(decoder_start_token_id) :]
+            else:
+                return []
+
+        return token_ids
@@ -1013,6 +1013,48 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
         self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
 
+    def test_generate_with_prompt_ids_and_task_and_language(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
+        input_features = input_dict["input_features"]
+        prompt_ids = np.arange(5)
+        language = "<|de|>"
+        task = "translate"
+        lang_id = 6
+        task_id = 7
+        model.generation_config.__setattr__("lang_to_id", {language: lang_id})
+        model.generation_config.__setattr__("task_to_id", {task: task_id})
+
+        output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
+
+        expected_output_start = [
+            *prompt_ids.tolist(),
+            model.generation_config.decoder_start_token_id,
+            lang_id,
+            task_id,
+        ]
+        for row in output.tolist():
+            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
+
+    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
+        input_features = input_dict["input_features"]
+        prompt_ids = np.asarray(range(5))
+        forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
+
+        output = model.generate(
+            input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
+        )
+
+        expected_output_start = [
+            *prompt_ids.tolist(),
+            model.generation_config.decoder_start_token_id,
+            *[token for _rank, token in forced_decoder_ids],
+        ]
+        for row in output.tolist():
+            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
+
 
 @require_torch
 @require_torchaudio
@@ -1429,6 +1471,60 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         # fmt: on
         self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
 
+    @slow
+    def test_generate_with_prompt_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(4)[-1:]
+        input_features = processor(input_speech, return_tensors="pt").input_features
+
+        output_without_prompt = model.generate(input_features)
+        prompt_ids = processor.get_prompt_ids("Leighton")
+        output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)
+
+        expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
+        expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
+        self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
+        self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)
+
+    @slow
+    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(1)
+        input_features = processor(input_speech, return_tensors="pt").input_features
+        task = "translate"
+        language = "de"
+        expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
+        prompt = "test prompt"
+        prompt_ids = processor.get_prompt_ids(prompt)
+
+        output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
+        text = processor.decode(output[0])
+
+        self.assertTrue(prompt in text)
+        self.assertTrue(all([token in text for token in expected_tokens]))
+
+    @slow
+    def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model.to(torch_device)
+        input_speech = self._load_datasamples(1)
+        input_features = processor(input_speech, return_tensors="pt").input_features
+        prompt = "test prompt"
+        prompt_ids = processor.get_prompt_ids(prompt)
+
+        model.generation_config.forced_decoder_ids = None
+        model.config.forced_decoder_ids = None
+
+        output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
+        text = processor.decode(output[0])
+
+        self.assertTrue(prompt in text)
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
@@ -16,6 +16,8 @@ import shutil
 import tempfile
 import unittest
 
+import pytest
+
 from transformers import WhisperTokenizer, is_speech_available
 from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
 
@@ -146,3 +148,32 @@ class WhisperProcessorTest(unittest.TestCase):
 
         expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
         self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
+
+    def test_get_prompt_ids(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+        prompt_ids = processor.get_prompt_ids("Mr. Quilter")
+        decoded_prompt = processor.tokenizer.decode(prompt_ids)
+
+        self.assertListEqual(prompt_ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
+        self.assertEqual(decoded_prompt, "<|startofprev|> Mr. Quilter")
+
+    def test_empty_get_prompt_ids(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+        prompt_ids = processor.get_prompt_ids("")
+        decoded_prompt = processor.tokenizer.decode(prompt_ids)
+
+        self.assertListEqual(prompt_ids.tolist(), [50360, 220])
+        self.assertEqual(decoded_prompt, "<|startofprev|> ")
+
+    def test_get_prompt_ids_with_special_tokens(self):
+        processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+
+        def _test_prompt_error_raised_helper(prompt, special_token):
+            with pytest.raises(ValueError) as excinfo:
+                processor.get_prompt_ids(prompt)
+            expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
+            self.assertEqual(expected, str(excinfo.value))
+
+        _test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
+        _test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
+        _test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")
@@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         merge = _find_longest_common_sequence([seq1, seq2, seq3])
         self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
 
+    def test_skip_special_tokens_skips_prompt_ids(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+        # fmt: off
+        encoded_input = [
+            50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
+            50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
+            2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
+        ]
+        # fmt: on
+        expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
+        expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
+        self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
+        self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
+        self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
+        self.assertEqual(
+            rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
+        )
+
 
 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"