Whisper Timestamp processor and prediction (#20620)

* add draft logit processor

* add template functions

* update timestamp processor parameters

* draft script

* simplify code

* cleanup

* fixup and clean

* update pipeline

* style

* clean up previous idea

* add tokenization utils

* update tokenizer and asr output

* fit whisper type

* style and update test

* clean test

* style test

* update tests

* update error test

* update code (not based on review yet)

* update tokenization

* update asr pipeline

* update code

* cleanup and update test

* fmt

* remove text verification

* cleanup

* cleanup

* add model test

* update tests

* update code, add docstring

* update code and add docstring

* fix pipeline tests

* Small update.

* Fixup.

* Tmp.

* More support.

* Making `forced_decoder_ids` non-mandatory for users to set.

* update and fix first bug

* properly process sequence right after merge if last

* todo

* allow list inputs + compute begin index better

* start adding tests

* add the 3 edge cases

* style

* format sequences

* fixup

* update

* update

* style

* test passes, edge cases should be good

* update last value

* remove Trie

* update tests and expected values

* handle bigger chunk_length

* clean tests a bit

* refactor chunk iter and clean pipeline

* update tests

* style

* refactor chunk iter and clean pipeline

* update

* resolve comments

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* take stride right into account

* update test expected values

* Update code based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Arthur 2023-01-17 15:50:09 +01:00 committed by GitHub
parent 25ddd91b24
commit bb300ac686
6 changed files with 737 additions and 40 deletions

View File

@ -801,3 +801,67 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
scores[:, :] = -float("inf")
scores[:, current_token] = 0
return scores
class WhisperTimeStampLogitsProcessor(LogitsProcessor):
r"""
Whisper specific Processor. It enforces the timestamp format expected by Whisper during generation: the
`<|notimestamps|>` token is suppressed, timestamp tokens have to appear in pairs, and whenever the total
probability of the timestamp tokens exceeds that of any single text token, a timestamp token is forced to be sampled.
Args:
begin_index (`int`, *optional*, defaults to 5):
Index at which the first freely generated token appears. This is used to differentiate between the
`prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
`prompt` tokens are the first 4 tokens.
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
timestamps that are too far in the future.
"""
def __init__(
self,
begin_index=5,
eos_token_id=50257,
no_timestamps_token_id=50363,
max_initial_timestamp=1,
):
self.eos_token_id = eos_token_id
self.no_timestamps_token_id = no_timestamps_token_id
self.timestamp_begin = no_timestamps_token_id + 1
self.begin_index = begin_index
self.max_initial_timestamp_index = max_initial_timestamp
def __call__(self, input_ids, scores):
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = [t for t in input_ids[k, self.begin_index :].tolist()]
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
scores[k, self.timestamp_begin :] = -float("inf")
else: # cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")
# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
scores[k, : self.timestamp_begin] = -float("inf")
return scores
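For reference, a minimal sketch of how the new processor can be plugged into `generate`, mirroring the integration test added further down (the `audio` array is a placeholder for a 16 kHz mono waveform and is not part of this diff):

from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# `audio` is assumed to be a 1-D float array sampled at 16 kHz (placeholder)
input_features = processor.feature_extractor(raw_speech=audio, return_tensors="pt").input_features

# force language/task/<|0.00|> tokens; `begin_index` is the number of forced decoder tokens
model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]

generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor)
print(processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True))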

View File

@ -17,6 +17,8 @@ import json
import os
from typing import List, Optional, Tuple, Union
import numpy as np
import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
@ -488,6 +490,91 @@ class WhisperTokenizer(PreTrainedTokenizer):
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text)
def _compute_offsets(self, token_ids, time_precision=0.02):
"""
Compute offsets for a given tokenized input
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time, i.e. the number of seconds represented by one timestamp step.
"""
offsets = []
token_ids = np.array(token_ids)
if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
raise ValueError("Can only process a single input at a time")
timestamp_begin = self.all_special_ids[-1] + 1
timestamp_tokens = token_ids >= timestamp_begin
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
if consecutive.shape[0] == 0 and timestamp_tokens.sum() <= 1:
# either there are no timestamps or there are no consecutive ones
return []
elif np.where(timestamp_tokens)[0][-1] + 1 not in consecutive:
# we add the final timestamp if it is not already in the list
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
last_slice = np.where(timestamp_tokens)[0][0]
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
offsets.append(
{
"text": self._decode(sliced_tokens),
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)
last_slice = current_slice
return offsets
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
output_offsets: bool = False,
time_precision=0.02,
**kwargs
) -> str:
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
tokens and clean up tokenization spaces.
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
Args:
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the tokenization spaces.
output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted timestamps.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns:
`str`: The decoded sentence. If `output_offsets=True`, a dictionary with the decoded `text` and the computed `offsets` is returned instead.
"""
text = super().decode(
token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
# retrieve offsets
if output_offsets:
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
return {"text": text, "offsets": offsets}
return text
def _decode(
self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
) -> str:
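A short sketch of the new `output_offsets` path, using the token ids from the tokenizer test added below (the ids wrap a single segment in a `<|0.00|> ... <|7.20|>` timestamp pair):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# a single segment delimited by the timestamp tokens 50364 (<|0.00|>) and 50724 (<|7.20|>)
token_ids = [
    50364, 441, 1857, 4174, 11, 5242, 366, 257, 1333, 295, 493, 2794, 2287,
    293, 12018, 14880, 11, 293, 25730, 311, 454, 34152, 4496, 904, 50724,
]
out = tokenizer.decode(token_ids, output_offsets=True)
# out["offsets"] is a list of {"text": ..., "timestamp": (start, end)} dicts;
# here it has a single entry with timestamp (0.0, 7.2), since 360 steps * 0.02 s = 7.2 s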

View File

@ -31,6 +31,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@ -54,7 +56,7 @@ def rescale_stride(stride, ratio):
return new_strides
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step):
@ -66,20 +68,135 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
_stride_left = 0 if i == 0 else stride_left
is_last = i + step + stride_left >= inputs_len
_stride_right = 0 if is_last else stride_right
if "input_features" in processed:
processed_len = processed["input_features"].shape[-1]
elif "input_values" in processed:
processed_len = processed["input_values"].shape[-1]
chunk_len = chunk.shape[0]
stride = (chunk_len, _stride_left, _stride_right)
if processed_len != chunk.shape[-1]:
ratio = processed_len / chunk_len
if ratio != 1:
stride = rescale_stride([stride], ratio)[0]
if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": stride, **processed}
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
"""
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
properly compute the final `offset`.
"""
# index of the first timestamp token
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
items = []
# approximation of the token to time ratio: chunk_length / max_source_positions = 30 / 1500 ≈ 0.02 seconds per token
time_precision = feature_extractor.chunk_length / max_source_positions
time = 0
actual_offset = 0
for seq_idx, item in enumerate(sequences):
sequence, stride = item
if isinstance(sequence, list):
sequence = np.array(sequence)
chunk_len, stride_left, stride_right = stride
sequence = sequence.squeeze(0)
# get rid of the `forced_decoder_ids` that are used to parametrize the generation
begin_idx = np.where(sequence == timestamp_begin)[0].item() if timestamp_begin in sequence else 0
sequence = sequence[begin_idx:]
if seq_idx != 0:
time -= stride_left + stride_right
offset = int((time / feature_extractor.sampling_rate) / time_precision)
timestamp_tokens = np.where(sequence >= timestamp_begin)[0][1::2]
if len(timestamp_tokens) >= 1:
# if a big chunk length is used, we need to check all of the previous items
best_match = 0
sliced_sequence = []
for idx, previous_sequence in enumerate(reversed(items)):
previous_tokens = previous_sequence[1:-1]
if len(previous_tokens) > 0:
index_left, index_right, match_length = _fast_find_longest_common_sequence(
sequence, previous_tokens
)
# don't do anything if only 1 token was matched
if match_length > 1 and match_length > best_match:
best_match = match_length
best_idx = idx
end_of_curr_sequence_idx = (
np.where(sequence[index_left:] >= timestamp_begin)[0][0] + 1 + index_left
)
sliced_sequence = sequence[index_left:end_of_curr_sequence_idx]
# if the whole previous sequence was matched at the start of the current one
if index_left == 0 and match_length == len(previous_tokens):
sliced_sequence[-1] = previous_sequence[-1]
# if part of the previous sequence is not taken
elif index_left > 0:
# let's insert the missing part of the previous sequence
sliced_sequence = np.insert(sliced_sequence, 0, previous_sequence[: index_right + 1])
sliced_sequence[-1] += offset
if len(sliced_sequence) > 0:
items[len(items) - best_idx - 1] = sliced_sequence
items = items[: len(items) - best_idx]
sequence = sequence[end_of_curr_sequence_idx:]
actual_offset = items[-1][-1] - timestamp_begin
timestamp_tokens = sequence >= timestamp_begin
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
last_slice = 0
# take the last timestamp of the previous chunk
for current_slice in consecutive:
sliced_tokens = sequence[last_slice:current_slice]
# set correct timestamps
sliced_tokens[0] += actual_offset
sliced_tokens[-1] += actual_offset
items.append(sliced_tokens) # correct final sequence
last_slice = current_slice
# check if we have a non consecutive timestamp at the end
if np.where(timestamp_tokens)[0][-1] != current_slice:
# offset = items[-1][-1] if len(items) > 0 else timestamp_begin
sliced_tokens = sequence[current_slice : np.where(timestamp_tokens)[0][-1] + 1]
sliced_tokens[0] += actual_offset
sliced_tokens[-1] += actual_offset
items.append(sliced_tokens)
else:
timestamps = sequence[timestamp_tokens.nonzero()[0].flatten()]
if len(timestamps) > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_idx = np.argwhere(sequence == timestamps[-1])[0][0]
sliced_sequence = sequence[: last_idx + 1]
duration = sliced_sequence[-1] - sliced_sequence[0]
# We need to discard the previous timing information
sliced_sequence[0] = items[-1][-1]
sliced_sequence[-1] = items[-1][-1] + duration
items.append(sliced_sequence)
# The beginning time of the next chunk
time += chunk_len
result = []
for i in range(len(items)):
result += items[i].tolist()
return result
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
seq_len_left = len(sequence_left)
seq_len_right = len(sequence_right)
counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
longest = 0
for i in range(seq_len_left):
for j in range(seq_len_right):
if sequence_left[i] == sequence_right[j]:
previous_counter = counter[i][j] + 1
counter[i + 1][j + 1] = previous_counter
if previous_counter > longest:
longest = previous_counter
counter = np.array(counter)
# we return the idx of the first element of the longest common sequence in the left sequence
index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
return index_left, index_right, longest
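To make the helper's return convention concrete, a small hand-checked example (not part of the diff):

from transformers.pipelines.automatic_speech_recognition import _fast_find_longest_common_sequence

# the longest common run [3, 4] starts at index 2 of the left sequence and index 0 of the right one
index_left, index_right, length = _fast_find_longest_common_sequence([1, 2, 3, 4], [3, 4, 5])
assert (index_left, index_right, length) == (2, 0, 2)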
def _find_longest_common_sequence(sequences, tokenizer):
# TODO Use a faster algorithm this can probably be done in O(n)
# using suffix array.
@ -181,7 +298,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
super().__init__(**kwargs)
self.feature_extractor = feature_extractor
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
if self.model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
self.type = "seq2seq"
elif (
feature_extractor._processor_class
@ -266,7 +385,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if ignore_warning is not None:
preprocess_params["ignore_warning"] = ignore_warning
forward_params = {"generate_kwargs": {}}
forward_params = defaultdict(dict)
if max_new_tokens is not None:
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
@ -282,6 +401,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_timestamps is not None:
postprocess_params["return_timestamps"] = return_timestamps
if self.model.config.model_type == "whisper":
# Whisper is highly specific: if we want timestamps, we need to
# force whisper to output timestamp tokens, which means we need
# to set this variable to prevent the `<|notimestamps|>` token from
# being used in the decoder.
if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
forward_params["generate_kwargs"]["forced_decoder_ids"] = None
return preprocess_params, forward_params, postprocess_params
@ -313,6 +439,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
_inputs = inputs.pop("raw", None)
if _inputs is None:
# Remove path which will not be used from `datasets`.
inputs.pop("path", None)
_inputs = inputs.pop("array", None)
in_sampling_rate = inputs.pop("sampling_rate")
extra = inputs
@ -369,7 +497,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# make sure that
for item in chunk_iter(
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
):
yield item
else:
@ -409,14 +537,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# `generate` magic to create the mask automatically won't work, we basically need to help
# it here.
attention_mask = model_inputs.pop("attention_mask", None)
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
**generate_kwargs,
)
out = {"tokens": tokens}
elif self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
tokens = self.model.generate(
input_features=model_inputs.pop("input_features"),
logits_processor=[WhisperTimeStampLogitsProcessor()],
**generate_kwargs,
)
out = {"tokens": tokens}
if stride is not None:
out["stride"] = stride
else:
stride = model_inputs.pop("stride", None)
@ -447,9 +583,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
optional = {}
if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens"
@ -465,12 +603,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# This won't work with left padding (which doesn't exist right now)
right_n = total_n - right
items = items[:, left:right_n]
if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
# Whisper needs the stride data
items = [items, stride]
final_items.append(items)
if stride and self.type == "seq2seq":
if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
items = _find_longest_common_sequence(final_items, self.tokenizer)
elif stride and self.type == "seq2seq_whisper" and return_timestamps:
items = _find_timestamp_sequence(
final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
)
else:
items = np.concatenate(final_items, axis=1)
items = items.squeeze(0)
if self.type == "ctc_with_lm":
if decoder_kwargs is None:
decoder_kwargs = {}
@ -483,24 +629,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
word_offsets = []
for word, (start_offset, end_offset) in chunk_offset:
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else:
skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
if return_timestamps:
char_offsets = self.tokenizer.decode(
if return_timestamps and self.type == "seq2seq_whisper":
offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
"offsets"
]
elif return_timestamps:
offsets = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
)["char_offsets"]
if return_timestamps == "word":
word_offsets = self.tokenizer._get_word_offsets(
char_offsets, self.tokenizer.replace_word_delimiter_char
)
offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)
if return_timestamps:
if return_timestamps == "word":
offsets = word_offsets
else:
offsets = char_offsets
if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
chunks = []
for item in offsets:
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
@ -511,6 +654,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
elif return_timestamps and self.type == "seq2seq_whisper":
optional["chunks"] = offsets
extra = defaultdict(list)
for output in model_outputs:
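Put together, the pipeline changes mean timestamps can be requested directly from the ASR pipeline; a minimal usage sketch (the audio path is a placeholder and the checkpoint choice is illustrative):

from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    return_timestamps=True,
)
# with `chunk_length_s`, the strided chunks go through the `seq2seq_whisper` path and are
# merged by `_find_timestamp_sequence` before the tokenizer computes the offsets
out = asr("sample.flac", chunk_length_s=30)
print(out["text"])
print(out["chunks"])  # [{"text": ..., "timestamp": (start, end)}, ...]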

View File

@ -20,6 +20,8 @@ import os
import tempfile
import unittest
import numpy as np
from transformers import WhisperConfig
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property
@ -44,6 +46,7 @@ if is_torch_available():
WhisperProcessor,
set_seed,
)
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
@ -1030,7 +1033,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
@slow
def test_tiny_en_batched_generation(self):
torch_device = "cuda"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
@ -1067,3 +1069,43 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]
generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")
# fmt: off
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
# fmt: on
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
# fmt: off
EXPECTED_TRANSCRIPT = [
{
'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
'offsets': [
{'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
{'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
{'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
{'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
{'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
{'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
]
}
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

View File

@ -227,3 +227,71 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)
def test_offset_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# fmt: off
INPUT_TOKENS = [
50258, 50259, 50359, 50364, 441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724, 50724, 366,
382, 4048, 382, 257, 361, 18459, 13065, 13, 2221, 13,
7145, 74, 325, 38756, 311, 29822, 7563, 412, 472, 709,
294, 264, 51122, 51122, 912, 636, 300, 2221, 13, 2741,
5767, 1143, 281, 7319, 702, 7798, 13, 400, 2221, 13,
2619, 4004, 811, 2709, 702, 51449, 51449, 50257
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(
output,
[
{
"text": (
" Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
),
"timestamp": (0.0, 7.2),
},
{
"text": (
" are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the"
),
"timestamp": (7.2, 15.16),
},
{
"text": " same way that Mr. Carker used to flash his teeth. And Mr. John Colier gives his",
"timestamp": (15.16, 21.7),
},
],
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [
50364, 441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(
output[0],
{
"text": " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles",
"timestamp": (0.0, 7.2),
},
)
# test a sequence without a single timestamps
# fmt: off
INPUT_TOKENS = [
441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

View File

@ -23,6 +23,7 @@ from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
AutoFeatureExtractor,
AutoProcessor,
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
@ -31,7 +32,7 @@ from transformers import (
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import chunk_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import (
is_torch_available,
nested_simplify,
@ -87,7 +88,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
elif "Whisper" in speech_recognizer.model.__class__.__name__:
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
# Non CTC models cannot use striding.
with self.assertRaises(ValueError):
@ -117,6 +120,18 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
elif "Whisper" in speech_recognizer.model.__class__.__name__:
outputs = speech_recognizer(audio, return_timestamps=True)
self.assertIsInstance(outputs["chunks"], list)
nb_chunks = len(outputs["chunks"])
self.assertGreater(nb_chunks, 0)
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(nb_chunks)],
},
)
else:
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
@ -142,7 +157,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, {"text": "(Applaudissements)"})
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
):
_ = speech_recognizer(waveform, return_timestamps="char")
@slow
@ -290,6 +307,280 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
@slow
def test_find_longest_common_subsequence(self):
max_source_positions = 1500
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
previous_sequence = [[51492, 406, 3163, 1953, 466, 13, 51612, 51612]]
self.assertEqual(
processor.decode(previous_sequence[0], output_offsets=True),
{
"text": " not worth thinking about.",
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
},
)
# Merge when the previous sequence is a suffix of the next sequence
# fmt: off
next_sequences_1 = [
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
]
# fmt: on
self.assertEqual(
processor.decode(next_sequences_1[0], output_offsets=True),
{
"text": (
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
" small, sharp blow high on his chest.<|endoftext|>"
),
"offsets": [
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (5.0, 9.4),
},
],
},
)
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_1, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 51739, 51739, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 27.5)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (27.5, 31.900000000000002),
},
],
},
)
# Merge when the sequence is in the middle of the 1st next sequence
# fmt: off
next_sequences_2 = [
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
]
# fmt: on
# {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_2, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on"
" his chest."
),
"timestamp": (22.56, 31.900000000000002),
},
],
},
)
# Merge when the previous sequence is not included in the current sequence
# fmt: off
next_sequences_3 = [[50364, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50584, 50257]]
# fmt: on
# {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51832],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.96, 29.36),
},
],
},
)
# last case: the overlapping sequence does not fall within the first predicted (start, end) timestamp pair of the next chunk
# fmt: off
next_sequences_3 = [
[50364, 2812, 9836, 14783, 390, 51492, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
]
# fmt: on
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 53112, 53112, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 53332],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.96, 29.36),
},
],
},
)
@slow
@require_torch
def test_whisper_timestamp_prediction(self):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
array = np.concatenate(
[ds[40]["audio"]["array"], ds[41]["audio"]["array"], ds[42]["audio"]["array"], ds[43]["audio"]["array"]]
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(ds[40]["audio"])
self.assertDictEqual(
output,
{
"text": " A man said to the universe, Sir, I exist.",
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
},
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
output,
{
"chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
{
"text": (
" Sweat covered Brion's body, trickling into the "
"tight-loan cloth that was the only garment he wore, the "
"cut"
),
"timestamp": (5.5, 11.94),
},
{
"text": (
" on his chest still dripping blood, the ache of his "
"overstrained eyes, even the soaring arena around him "
"with"
),
"timestamp": (11.94, 19.6),
},
{
"text": " the thousands of spectators, retrievality is not worth thinking about.",
"timestamp": (19.6, 24.98),
},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.98, 30.98),
},
],
"text": (
" A man said to the universe, Sir, I exist. Sweat covered Brion's "
"body, trickling into the tight-loan cloth that was the only garment "
"he wore, the cut on his chest still dripping blood, the ache of his "
"overstrained eyes, even the soaring arena around him with the "
"thousands of spectators, retrievality is not worth thinking about. "
"His instant panic was followed by a small, sharp blow high on his "
"chest."
),
},
)
output = pipe(array)
self.assertDictEqual(
output,
{
"chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
{
"text": (
" Sweat covered Brion's body, trickling into the "
"tight-loan cloth that was the only garment"
),
"timestamp": (5.5, 10.18),
},
{"text": " he wore.", "timestamp": (10.18, 11.68)},
{"text": " The cut on his chest still dripping blood.", "timestamp": (11.68, 14.92)},
{"text": " The ache of his overstrained eyes.", "timestamp": (14.92, 17.6)},
{
"text": (
" Even the soaring arena around him with the thousands of spectators were trivialities"
),
"timestamp": (17.6, 22.56),
},
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
],
"text": (
" A man said to the universe, Sir, I exist. Sweat covered Brion's "
"body, trickling into the tight-loan cloth that was the only garment "
"he wore. The cut on his chest still dripping blood. The ache of his "
"overstrained eyes. Even the soaring arena around him with the "
"thousands of spectators were trivialities not worth thinking about."
),
},
)
@require_torch
@slow
def test_torch_speech_encoder_decoder(self):
@ -724,22 +1015,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
def test_chunk_iterator(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = torch.arange(100).long()
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0, ratio))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
# two chunks no stride
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
# two chunks incomplete last
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
@ -750,7 +1041,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
# This test is specifically crafted to trigger a bug if next chunk
# would be ignored by the fact that all the data would be
# contained in the strided left data.
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5, ratio))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
@ -763,20 +1054,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
@ -785,7 +1076,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5, ratio))
self.assertEqual(len(outs), 5)
self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])