Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Whisper Timestamp processor and prediction (#20620)
* add draft logit processor
* add template functions
* update timestamp processor parameters
* draft script
* simplify code
* cleanup
* fixup and clean
* update pipeline
* style
* clean up previous idea
* add tokenization utils
* update tokenizer and asr output
* fit whisper type
* style and update test
* clean test
* style test
* update tests
* update error test
* update code (not based on review yet)
* update tokenization
* update asr pipeline
* update code
* cleanup and update test
* fmt
* remove text verification
* cleanup
* add model test
* update tests
* update code add docstring
* update code and add docstring
* fix pipeline tests
* Small update.
* Fixup.
* Tmp.
* More support.
* Making `forced_decoder_ids` non mandatory for users to set.
* update and fix first bug
* properly process sequence right after merge if last
* todo
* allow list inputs + compute begin index better
* start adding tests
* add the 3 edge cases
* style
* format sequences
* fixup
* update
* style
* test passes, edge cases should be good
* update last value
* remove Trie
* update tests and expected values
* handle bigger chunk_length
* clean tests a bit
* refactor chunk iter and clean pipeline
* update tests
* style
* resolve comments
* Apply suggestions from code review
* take stride right into account
* update test expected values
* Update code based on review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
parent
25ddd91b24
commit
bb300ac686
@@ -801,3 +801,67 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
             scores[:, :] = -float("inf")
             scores[:, current_token] = 0
         return scores
+
+
+class WhisperTimeStampLogitsProcessor(LogitsProcessor):
+    r"""
+    Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
+    probs to `inf` so that they are sampled at their corresponding index.
+
+    Args:
+        begin_index (`int`, *optional*, defaults to 5):
+            This indicates to the processor where the first tokens are generated. This is used to differentiate between
+            the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
+            `prompt` tokens are the first 4 tokens.
+        eos_token_id (`int`, *optional*, defaults to 50257):
+            The id of the *end-of-sequence* token.
+        no_timestamps_token_id (`int`, *optional*, defaults to 50363):
+            The id of the `"<|notimestamps|>"` token.
+        max_initial_timestamp (`int`, *optional*, defaults to 1):
+            Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
+            timestamps that are too far in the future.
+    """
+
+    def __init__(
+        self,
+        begin_index=5,
+        eos_token_id=50257,
+        no_timestamps_token_id=50363,
+        max_initial_timestamp=1,
+    ):
+        self.eos_token_id = eos_token_id
+        self.no_timestamps_token_id = no_timestamps_token_id
+        self.timestamp_begin = no_timestamps_token_id + 1
+        self.begin_index = begin_index
+        self.max_initial_timestamp_index = max_initial_timestamp
+
+    def __call__(self, input_ids, scores):
+        # suppress <|notimestamps|> which is handled by without_timestamps
+        scores[:, self.no_timestamps_token_id] = -float("inf")
+
+        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
+        for k in range(input_ids.shape[0]):
+            seq = [t for t in input_ids[k, self.begin_index :].tolist()]
+            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
+            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
+
+            if last_was_timestamp:
+                if penultimate_was_timestamp:  # has to be non-timestamp
+                    scores[k, self.timestamp_begin :] = -float("inf")
+                else:  # cannot be normal text tokens
+                    scores[k, : self.eos_token_id] = -float("inf")
+
+        # apply the `max_initial_timestamp` option
+        if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
+            last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
+            scores[:, last_allowed + 1 :] = -float("inf")
+
+        # if sum of probability over timestamps is above any other token, sample timestamp
+        logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
+        for k in range(input_ids.shape[0]):
+            timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
+            max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
+            if timestamp_logprob > max_text_token_logprob:
+                scores[k, : self.timestamp_begin] = -float("inf")
+
+        return scores
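For orientation only (not part of the diff): a minimal sketch of how such a processor is meant to be used, matching the pipeline and test changes further down. The checkpoint name and the `audio` array are placeholders.

# Hedged sketch: constrain `generate` with the timestamp processor so Whisper emits
# well-formed timestamp token pairs that the tokenizer can later turn into offsets.
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# `audio` is assumed to be a 1-D float waveform sampled at 16 kHz
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
generated_ids = model.generate(
    input_features,
    max_length=448,
    logits_processor=[WhisperTimeStampLogitsProcessor()],
)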
@@ -17,6 +17,8 @@ import json
 import os
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
+
 import regex as re
 
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
@@ -488,6 +490,91 @@ class WhisperTokenizer(PreTrainedTokenizer):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)
 
+    def _compute_offsets(self, token_ids, time_precision=0.02):
+        """
+        Compute offsets for a given tokenized input
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        offsets = []
+        token_ids = np.array(token_ids)
+        if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
+            raise ValueError("Can only process a single input at a time")
+        timestamp_begin = self.all_special_ids[-1] + 1
+        timestamp_tokens = token_ids >= timestamp_begin
+
+        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+        if consecutive.shape[0] == 0 and timestamp_tokens.sum() <= 1:
+            # either there are no timestamps or there are no consecutive ones
+            return []
+        elif np.where(timestamp_tokens)[0][-1] + 1 not in consecutive:
+            # we add the final timestamp if it is not already in the list
+            consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
+
+        last_slice = np.where(timestamp_tokens)[0][0]
+        for current_slice in consecutive:
+            sliced_tokens = token_ids[last_slice:current_slice]
+            if len(sliced_tokens) > 1:
+                start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
+                end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+                offsets.append(
+                    {
+                        "text": self._decode(sliced_tokens),
+                        "timestamp": (
+                            start_timestamp_position * time_precision,
+                            end_timestamp_position * time_precision,
+                        ),
+                    }
+                )
+            last_slice = current_slice
+
+        return offsets
+
+    def decode(
+        self,
+        token_ids,
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = True,
+        output_offsets: bool = False,
+        time_precision=0.02,
+        **kwargs
+    ) -> str:
+        """
+        Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
+        tokens and clean up tokenization spaces.
+
+        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+                Whether or not to clean up the tokenization spaces.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
+
+        Returns:
+            `str`: The decoded sentence.
+        """
+        text = super().decode(
+            token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+        # retrieve offsets
+        if output_offsets:
+            offsets = None
+            offsets = self._compute_offsets(token_ids, time_precision=time_precision)
+            return {"text": text, "offsets": offsets}
+        return text
+
     def _decode(
         self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
     ) -> str:
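Illustrative only: what the new `output_offsets` path returns. Timestamp tokens are converted to seconds with `time_precision` (0.02 s per token), so token 50364 marks 0.0 s and token 50724 marks (50724 - 50364) * 0.02 = 7.2 s, matching the tokenizer test added below. `token_ids` here is a placeholder for ids produced by a Whisper model with timestamps enabled.

# Hedged sketch of decoding with offsets
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
out = tokenizer.decode(token_ids, output_offsets=True)
# out["text"]    -> the full transcription string
# out["offsets"] -> [{"text": "...", "timestamp": (start_seconds, end_seconds)}, ...]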
@@ -31,6 +31,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 if is_torch_available():
+    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+
     from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
 
 
@@ -54,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
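As a quick illustration of the stepping logic in `chunk_iter` (a sketch, not library code): every iteration advances by `chunk_len - stride_left - stride_right`, and the new `ratio` argument rescales the reported stride when the feature extractor changes the sequence length. With a 100-sample input, `chunk_len=80` and no stride, this yields the two chunks the updated tests further down expect.

# Hedged sketch of the chunk stepping arithmetic only (no feature extraction)
def chunk_bounds(inputs_len, chunk_len, stride_left, stride_right):
    step = chunk_len - stride_left - stride_right
    for i in range(0, inputs_len, step):
        _stride_left = 0 if i == 0 else stride_left
        is_last = i + step + stride_left >= inputs_len
        _stride_right = 0 if is_last else stride_right
        yield i, min(i + chunk_len, inputs_len), _stride_left, _stride_right, is_last

print(list(chunk_bounds(100, 80, 0, 0)))
# [(0, 80, 0, 0, False), (80, 100, 0, 0, True)]  -> chunk lengths 80 and 20, as in the tests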
@@ -66,20 +68,135 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
 
-        if "input_features" in processed:
-            processed_len = processed["input_features"].shape[-1]
-        elif "input_values" in processed:
-            processed_len = processed["input_values"].shape[-1]
         chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if processed_len != chunk.shape[-1]:
-            ratio = processed_len / chunk_len
+        if ratio != 1:
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
 
 
+def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
+    """
+    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
+    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
+    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
+    processed. We need to make sure to offset the timestamp tokens by the `time` in order for the tokenizer to
+    properly compute the final `offset`.
+    """
+    # index of the first timestamp token
+    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
+    items = []
+    # approximation of the token to time ratio : ~0.02 seconds
+    time_precision = feature_extractor.chunk_length / max_source_positions
+    time = 0
+    actual_offset = 0
+    for seq_idx, item in enumerate(sequences):
+        sequence, stride = item
+        if isinstance(sequence, list):
+            sequence = np.array(sequence)
+        chunk_len, stride_left, stride_right = stride
+        sequence = sequence.squeeze(0)
+        # get rid of the `forced_decoder_idx` that are used to parametrize the generation
+        begin_idx = np.where(sequence == timestamp_begin)[0].item() if timestamp_begin in sequence else 0
+        sequence = sequence[begin_idx:]
+
+        if seq_idx != 0:
+            time -= stride_left + stride_right
+            offset = int((time / feature_extractor.sampling_rate) / time_precision)
+            timestamp_tokens = np.where(sequence >= timestamp_begin)[0][1::2]
+            if len(timestamp_tokens) >= 1:
+                # if a big chunk length is used, we need to check all of the previous items
+                best_match = 0
+                sliced_sequence = []
+                for idx, previous_sequence in enumerate(reversed(items)):
+                    previous_tokens = previous_sequence[1:-1]
+                    if len(previous_tokens) > 0:
+                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
+                            sequence, previous_tokens
+                        )
+                        # don't do anything if only 1 token was matched
+                        if match_length > 1 and match_length > best_match:
+                            best_match = match_length
+                            best_idx = idx
+                            end_of_curr_sequence_idx = (
+                                np.where(sequence[index_left:] >= timestamp_begin)[0][0] + 1 + index_left
+                            )
+                            sliced_sequence = sequence[index_left:end_of_curr_sequence_idx]
+                            # if all the tokens are matched, suffix
+                            if index_left == 0 and match_length == len(previous_tokens):
+                                sliced_sequence[-1] = previous_sequence[-1]
+                            # if part of the previous sequence is not taken
+                            elif index_left > 0:
+                                # let's insert the missing part of the previous sequence
+                                sliced_sequence = np.insert(sliced_sequence, 0, previous_sequence[: index_right + 1])
+                                sliced_sequence[-1] += offset
+                if len(sliced_sequence) > 0:
+                    items[len(items) - best_idx - 1] = sliced_sequence
+                    items = items[: len(items) - best_idx]
+                    sequence = sequence[end_of_curr_sequence_idx:]
+            actual_offset = items[-1][-1] - timestamp_begin
+
+        timestamp_tokens = sequence >= timestamp_begin
+        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+
+        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+            last_slice = 0
+            # take the last timestamp of the previous chunk
+            for current_slice in consecutive:
+                sliced_tokens = sequence[last_slice:current_slice]
+                # set correct timestamps
+                sliced_tokens[0] += actual_offset
+                sliced_tokens[-1] += actual_offset
+                items.append(sliced_tokens)  # correct final sequence
+                last_slice = current_slice
+            # check if we have a non consecutive timestamp at the end
+            if np.where(timestamp_tokens)[0][-1] != current_slice:
+                # offset = items[-1][-1] if len(items) > 0 else timestamp_begin
+                sliced_tokens = sequence[current_slice : np.where(timestamp_tokens)[0][-1] + 1]
+                sliced_tokens[0] += actual_offset
+                sliced_tokens[-1] += actual_offset
+                items.append(sliced_tokens)
+        else:
+            timestamps = sequence[timestamp_tokens.nonzero()[0].flatten()]
+            if len(timestamps) > 0 and timestamps[-1].item() != timestamp_begin:
+                # no consecutive timestamps but it has a timestamp; use the last one.
+                # single timestamp at the end means no speech after the last timestamp.
+                last_idx = np.argwhere(sequence == timestamps[-1])[0][0]
+                sliced_sequence = sequence[: last_idx + 1]
+                duration = sliced_sequence[-1] - sliced_sequence[0]
+                # We need to discard the previous timing information
+                sliced_sequence[0] = items[-1][-1]
+                sliced_sequence[-1] = items[-1][-1] + duration
+                items.append(sliced_sequence)
+        # The beginning time of the next chunk
+        time += chunk_len
+    result = []
+    for i in range(len(items)):
+        result += items[i].tolist()
+    return result
+
+
+def _fast_find_longest_common_sequence(sequence_left, sequence_right):
+    seq_len_left = len(sequence_left)
+    seq_len_right = len(sequence_right)
+    counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
+    longest = 0
+    for i in range(seq_len_left):
+        for j in range(seq_len_right):
+            if sequence_left[i] == sequence_right[j]:
+                previous_counter = counter[i][j] + 1
+                counter[i + 1][j + 1] = previous_counter
+                if previous_counter > longest:
+                    longest = previous_counter
+
+    counter = np.array(counter)
+    # we return the idx of the first element of the longest common sequence in the left sequence
+    index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
+    index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
+    return index_left, index_right, longest
+
+
 def _find_longest_common_sequence(sequences, tokenizer):
     # TODO Use a faster algorithm this can probably be done in O(n)
     # using suffix array.
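A toy check of the helper above (illustrative; it assumes the private function stays importable from the pipeline module, as the tests below do): the DP finds the longest common contiguous run and returns its start index in each sequence plus the match length.

import numpy as np
from transformers.pipelines.automatic_speech_recognition import _fast_find_longest_common_sequence

left = np.array([7, 8, 1, 2, 3, 9])
right = np.array([1, 2, 3, 4])
# the longest shared run is [1, 2, 3]: index 2 in `left`, index 0 in `right`, length 3
print(_fast_find_longest_common_sequence(left, right))  # expected: (2, 0, 3)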
@@ -181,7 +298,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         super().__init__(**kwargs)
         self.feature_extractor = feature_extractor
 
-        if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
+        if self.model.config.model_type == "whisper":
+            self.type = "seq2seq_whisper"
+        elif self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
             self.type = "seq2seq"
         elif (
             feature_extractor._processor_class
@@ -266,7 +385,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if ignore_warning is not None:
             preprocess_params["ignore_warning"] = ignore_warning
 
-        forward_params = {"generate_kwargs": {}}
+        forward_params = defaultdict(dict)
         if max_new_tokens is not None:
             forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
         if generate_kwargs is not None:
@@ -282,6 +401,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
             postprocess_params["return_timestamps"] = return_timestamps
+            if self.model.config.model_type == "whisper":
+                # Whisper is highly specific, if we want timestamps, we need to
+                # force whisper to output timestamp tokens, which means we need
+                # to set this variable to prevent `no_timestamp_token` from being
+                # used in the decoder.
+                if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
+                    forward_params["generate_kwargs"]["forced_decoder_ids"] = None
 
         return preprocess_params, forward_params, postprocess_params
 
@@ -313,6 +439,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
 
             _inputs = inputs.pop("raw", None)
             if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
                 _inputs = inputs.pop("array", None)
             in_sampling_rate = inputs.pop("sampling_rate")
             extra = inputs
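For context (not part of the diff), this is the dictionary input shape the `preprocess` branch above handles, e.g. a row coming from `datasets`; the pipeline object and the waveform are placeholders:

# Hedged sketch of a dict input; the new code simply drops the unused "path" key
sample = {
    "array": waveform,        # 1-D float array ("raw" is accepted as well)
    "sampling_rate": 16000,
    "path": "/path/to/audio.flac",
}
result = speech_recognizer(sample)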
@@ -369,7 +497,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
 
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
             ):
                 yield item
         else:
@@ -409,14 +537,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # `generate` magic to create the mask automatically won't work, we basically need to help
             # it here.
             attention_mask = model_inputs.pop("attention_mask", None)
 
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
                 **generate_kwargs,
             )
 
             out = {"tokens": tokens}
+        elif self.type == "seq2seq_whisper":
+            stride = model_inputs.pop("stride", None)
+            tokens = self.model.generate(
+                input_features=model_inputs.pop("input_features"),
+                logits_processor=[WhisperTimeStampLogitsProcessor()],
+                **generate_kwargs,
+            )
+            out = {"tokens": tokens}
+            if stride is not None:
+                out["stride"] = stride
 
         else:
             stride = model_inputs.pop("stride", None)
@@ -447,9 +583,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         optional = {}
 
         if return_timestamps and self.type == "seq2seq":
-            raise ValueError("We cannot return_timestamps yet on non-ctc models !")
+            raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
         if return_timestamps == "char" and self.type == "ctc_with_lm":
             raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
+        if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
+            raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
 
         final_items = []
         key = "logits" if self.type == "ctc_with_lm" else "tokens"
@@ -465,12 +603,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 # This won't work with left padding (which doesn't exist right now)
                 right_n = total_n - right
                 items = items[:, left:right_n]
+            if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
+                # Whisper needs the stride data
+                items = [items, stride]
             final_items.append(items)
-        if stride and self.type == "seq2seq":
+        if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
             items = _find_longest_common_sequence(final_items, self.tokenizer)
+        elif stride and self.type == "seq2seq_whisper" and return_timestamps:
+            items = _find_timestamp_sequence(
+                final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
+            )
         else:
             items = np.concatenate(final_items, axis=1)
             items = items.squeeze(0)
 
         if self.type == "ctc_with_lm":
             if decoder_kwargs is None:
                 decoder_kwargs = {}
@@ -483,24 +629,21 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 word_offsets = []
                 for word, (start_offset, end_offset) in chunk_offset:
                     word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
 
         else:
             skip_special_tokens = self.type != "ctc"
             text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
-            if return_timestamps:
-                char_offsets = self.tokenizer.decode(
+            if return_timestamps and self.type == "seq2seq_whisper":
+                offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
+                    "offsets"
+                ]
+            elif return_timestamps:
+                offsets = self.tokenizer.decode(
                     items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                 )["char_offsets"]
                 if return_timestamps == "word":
-                    word_offsets = self.tokenizer._get_word_offsets(
-                        char_offsets, self.tokenizer.replace_word_delimiter_char
-                    )
+                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)
 
-        if return_timestamps:
-            if return_timestamps == "word":
-                offsets = word_offsets
-            else:
-                offsets = char_offsets
+        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
             chunks = []
             for item in offsets:
                 start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
@@ -511,6 +654,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
 
                 chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
             optional["chunks"] = chunks
+        elif return_timestamps and self.type == "seq2seq_whisper":
+            optional["chunks"] = offsets
 
         extra = defaultdict(list)
         for output in model_outputs:
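Putting the pipeline pieces together (a sketch mirroring the new `test_whisper_timestamp_prediction` further down; the audio file name is a placeholder):

from transformers import pipeline

pipe = pipeline(model="openai/whisper-small", return_timestamps=True)
output = pipe("audio.flac", chunk_length_s=10)
# output["text"]   -> full transcription
# output["chunks"] -> [{"text": "...", "timestamp": (start_s, end_s)}, ...]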
@@ -20,6 +20,8 @@ import os
 import tempfile
 import unittest
 
+import numpy as np
+
 from transformers import WhisperConfig
 from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
 from transformers.utils import cached_property
@@ -44,6 +46,7 @@ if is_torch_available():
         WhisperProcessor,
         set_seed,
     )
+    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
 
 
@@ -1030,7 +1033,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_en_batched_generation(self):
-        torch_device = "cuda"
         set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
@@ -1067,3 +1069,43 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
         self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
+
+    @slow
+    def test_tiny_timestamp_generation(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+
+        input_speech = np.concatenate(self._load_datasamples(4))
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
+            torch_device
+        )
+        model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
+        timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]
+        generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")
+
+        # fmt: off
+        EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
+        # fmt: on
+
+        self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
+
+        # fmt: off
+        EXPECTED_TRANSCRIPT = [
+            {
+                'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
+                'offsets': [
+                    {'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
+                    {'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
+                    {'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
+                    {'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
+                    {'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
+                    {'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
+                ]
+            }
+        ]
+        # fmt: on
+
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@@ -227,3 +227,71 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
         batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
         transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
         self.assertListEqual(batch, transcription)
+
+    def test_offset_decoding(self):
+        multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
+        # fmt: off
+        INPUT_TOKENS = [
+            50258, 50259, 50359, 50364, 441, 1857, 4174, 11, 5242, 366,
+            257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
+            293, 25730, 311, 454, 34152, 4496, 904, 50724, 50724, 366,
+            382, 4048, 382, 257, 361, 18459, 13065, 13, 2221, 13,
+            7145, 74, 325, 38756, 311, 29822, 7563, 412, 472, 709,
+            294, 264, 51122, 51122, 912, 636, 300, 2221, 13, 2741,
+            5767, 1143, 281, 7319, 702, 7798, 13, 400, 2221, 13,
+            2619, 4004, 811, 2709, 702, 51449, 51449, 50257
+        ]
+        # fmt: on
+        output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
+
+        self.assertEqual(
+            output,
+            [
+                {
+                    "text": (
+                        " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
+                    ),
+                    "timestamp": (0.0, 7.2),
+                },
+                {
+                    "text": (
+                        " are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the"
+                    ),
+                    "timestamp": (7.2, 15.16),
+                },
+                {
+                    "text": " same way that Mr. Carker used to flash his teeth. And Mr. John Colier gives his",
+                    "timestamp": (15.16, 21.7),
+                },
+            ],
+        )
+
+        # test a single sequence with timestamps
+        # fmt: off
+        INPUT_TOKENS = [
+            50364, 441, 1857, 4174, 11, 5242, 366,
+            257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
+            293, 25730, 311, 454, 34152, 4496, 904, 50724
+        ]
+        # fmt: on
+
+        output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
+        self.assertEqual(
+            output[0],
+            {
+                "text": " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles",
+                "timestamp": (0.0, 7.2),
+            },
+        )
+
+        # test a sequence without a single timestamp
+        # fmt: off
+        INPUT_TOKENS = [
+            441, 1857, 4174, 11, 5242, 366,
+            257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
+            293, 25730, 311, 454, 34152, 4496, 904, 50724
+        ]
+        # fmt: on
+
+        output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
+        self.assertEqual(output, [])
@@ -23,6 +23,7 @@ from transformers import (
     MODEL_FOR_CTC_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     AutoFeatureExtractor,
+    AutoProcessor,
     AutoTokenizer,
     Speech2TextForConditionalGeneration,
     Wav2Vec2ForCTC,
@@ -31,7 +32,7 @@ from transformers import (
 )
 from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
 from transformers.pipelines.audio_utils import chunk_bytes_iter
-from transformers.pipelines.automatic_speech_recognition import chunk_iter
+from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
 from transformers.testing_utils import (
     is_torch_available,
     nested_simplify,
@@ -87,7 +88,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         if speech_recognizer.type == "ctc":
             outputs = speech_recognizer(audio)
             self.assertEqual(outputs, {"text": ANY(str)})
+        elif "Whisper" in speech_recognizer.model.__class__.__name__:
+            outputs = speech_recognizer(audio)
+            self.assertEqual(outputs, {"text": ANY(str)})
         else:
             # Non CTC models cannot use striding.
             with self.assertRaises(ValueError):
@@ -117,6 +120,18 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
                     "chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
                 },
             )
+        elif "Whisper" in speech_recognizer.model.__class__.__name__:
+            outputs = speech_recognizer(audio, return_timestamps=True)
+            self.assertIsInstance(outputs["chunks"], list)
+            nb_chunks = len(outputs["chunks"])
+            self.assertGreater(nb_chunks, 0)
+            self.assertEqual(
+                outputs,
+                {
+                    "text": ANY(str),
+                    "chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(nb_chunks)],
+                },
+            )
         else:
             # Non CTC models cannot use return_timestamps
             with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
@@ -142,7 +157,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         self.assertEqual(output, {"text": "(Applaudissements)"})
 
         # Non CTC models cannot use return_timestamps
-        with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
+        with self.assertRaisesRegex(
+            ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
+        ):
             _ = speech_recognizer(waveform, return_timestamps="char")
 
     @slow
@@ -290,6 +307,280 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
         self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
 
+    @slow
+    def test_find_longest_common_subsequence(self):
+        max_source_positions = 1500
+        processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
+
+        previous_sequence = [[51492, 406, 3163, 1953, 466, 13, 51612, 51612]]
+        self.assertEqual(
+            processor.decode(previous_sequence[0], output_offsets=True),
+            {
+                "text": " not worth thinking about.",
+                "offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
+            },
+        )
+
+        # Merge when the previous sequence is a suffix of the next sequence
+        # fmt: off
+        next_sequences_1 = [
+            [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
+        ]
+        # fmt: on
+        self.assertEqual(
+            processor.decode(next_sequences_1[0], output_offsets=True),
+            {
+                "text": (
+                    " of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
+                    " small, sharp blow high on his chest.<|endoftext|>"
+                ),
+                "offsets": [
+                    {"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
+                    {
+                        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
+                        "timestamp": (5.0, 9.4),
+                    },
+                ],
+            },
+        )
+        merge = _find_timestamp_sequence(
+            [[previous_sequence, (3000, 0, 0)], [next_sequences_1, (3000, 750, 0)]],
+            processor.tokenizer,
+            processor.feature_extractor,
+            max_source_positions,
+        )
+
+        # fmt: off
+        self.assertEqual(
+            merge,
+            [51492, 406, 3163, 1953, 466, 13, 51739, 51739, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
+        )
+        # fmt: on
+        self.assertEqual(
+            processor.decode(merge, output_offsets=True),
+            {
+                "text": (
+                    " not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
+                    " chest."
+                ),
+                "offsets": [
+                    {"text": " not worth thinking about.", "timestamp": (22.56, 27.5)},
+                    {
+                        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
+                        "timestamp": (27.5, 31.900000000000002),
+                    },
+                ],
+            },
+        )
+
+        # Merge when the sequence is in the middle of the 1st next sequence
+        # fmt: off
+        next_sequences_2 = [
+            [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
+        ]
+        # fmt: on
+        # {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
+        merge = _find_timestamp_sequence(
+            [[previous_sequence, (3000, 0, 0)], [next_sequences_2, (3000, 750, 0)]],
+            processor.tokenizer,
+            processor.feature_extractor,
+            max_source_positions,
+        )
+        # fmt: off
+        self.assertEqual(
+            merge,
+            [51492, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
+        )
+        # fmt: on
+        self.assertEqual(
+            processor.decode(merge, output_offsets=True),
+            {
+                "text": (
+                    " not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
+                    " chest."
+                ),
+                "offsets": [
+                    {
+                        "text": (
+                            " not worth thinking about. His instant panic was followed by a small, sharp blow high on"
+                            " his chest."
+                        ),
+                        "timestamp": (22.56, 31.900000000000002),
+                    },
+                ],
+            },
+        )
+
+        # Merge when the previous sequence is not included in the current sequence
+        # fmt: off
+        next_sequences_3 = [[50364, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50584, 50257]]
+        # fmt: on
+        # {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
+        merge = _find_timestamp_sequence(
+            [[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
+            processor.tokenizer,
+            processor.feature_extractor,
+            max_source_positions,
+        )
+        # fmt: off
+        self.assertEqual(
+            merge,
+            [51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51832],
+        )
+        # fmt: on
+        self.assertEqual(
+            processor.decode(merge, output_offsets=True),
+            {
+                "text": (
+                    " not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
+                    " chest."
+                ),
+                "offsets": [
+                    {"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
+                    {
+                        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
+                        "timestamp": (24.96, 29.36),
+                    },
+                ],
+            },
+        )
+        # last case is when the sequence is not in the first next predicted start and end of timestamp
+        # fmt: off
+        next_sequences_3 = [
+            [50364, 2812, 9836, 14783, 390, 51492, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
+        ]
+        # fmt: on
+        merge = _find_timestamp_sequence(
+            [[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
+            processor.tokenizer,
+            processor.feature_extractor,
+            max_source_positions,
+        )
+        # fmt: off
+        self.assertEqual(
+            merge,
+            [51492, 406, 3163, 1953, 466, 13, 53112, 53112, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 53332],
+        )
+        # fmt: on
+        self.assertEqual(
+            processor.decode(merge, output_offsets=True),
+            {
+                "text": (
+                    " not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
+                    " chest."
+                ),
+                "offsets": [
+                    {"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
+                    {
+                        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
+                        "timestamp": (24.96, 29.36),
+                    },
+                ],
+            },
+        )
+
+    @slow
+    @require_torch
+    def test_whisper_timestamp_prediction(self):
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        array = np.concatenate(
+            [ds[40]["audio"]["array"], ds[41]["audio"]["array"], ds[42]["audio"]["array"], ds[43]["audio"]["array"]]
+        )
+        pipe = pipeline(
+            model="openai/whisper-small",
+            return_timestamps=True,
+        )
+
+        output = pipe(ds[40]["audio"])
+        self.assertDictEqual(
+            output,
+            {
+                "text": " A man said to the universe, Sir, I exist.",
+                "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
+            },
+        )
+        pipe = pipeline(
+            model="openai/whisper-small",
+            return_timestamps=True,
+        )
+
+        output = pipe(array, chunk_length_s=10)
+        self.assertDictEqual(
+            output,
+            {
+                "chunks": [
+                    {"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
+                    {
+                        "text": (
+                            " Sweat covered Brion's body, trickling into the "
+                            "tight-loan cloth that was the only garment he wore, the "
+                            "cut"
+                        ),
+                        "timestamp": (5.5, 11.94),
+                    },
+                    {
+                        "text": (
+                            " on his chest still dripping blood, the ache of his "
+                            "overstrained eyes, even the soaring arena around him "
+                            "with"
+                        ),
+                        "timestamp": (11.94, 19.6),
+                    },
+                    {
+                        "text": " the thousands of spectators, retrievality is not worth thinking about.",
+                        "timestamp": (19.6, 24.98),
+                    },
+                    {
+                        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
+                        "timestamp": (24.98, 30.98),
+                    },
+                ],
+                "text": (
+                    " A man said to the universe, Sir, I exist. Sweat covered Brion's "
+                    "body, trickling into the tight-loan cloth that was the only garment "
+                    "he wore, the cut on his chest still dripping blood, the ache of his "
+                    "overstrained eyes, even the soaring arena around him with the "
+                    "thousands of spectators, retrievality is not worth thinking about. "
+                    "His instant panic was followed by a small, sharp blow high on his "
+                    "chest."
+                ),
+            },
+        )
+
+        output = pipe(array)
+        self.assertDictEqual(
+            output,
+            {
+                "chunks": [
+                    {"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
+                    {
+                        "text": (
+                            " Sweat covered Brion's body, trickling into the "
+                            "tight-loan cloth that was the only garment"
+                        ),
+                        "timestamp": (5.5, 10.18),
+                    },
+                    {"text": " he wore.", "timestamp": (10.18, 11.68)},
+                    {"text": " The cut on his chest still dripping blood.", "timestamp": (11.68, 14.92)},
+                    {"text": " The ache of his overstrained eyes.", "timestamp": (14.92, 17.6)},
+                    {
+                        "text": (
+                            " Even the soaring arena around him with the thousands of spectators were trivialities"
+                        ),
+                        "timestamp": (17.6, 22.56),
+                    },
+                    {"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
+                ],
+                "text": (
+                    " A man said to the universe, Sir, I exist. Sweat covered Brion's "
+                    "body, trickling into the tight-loan cloth that was the only garment "
+                    "he wore. The cut on his chest still dripping blood. The ache of his "
+                    "overstrained eyes. Even the soaring arena around him with the "
+                    "thousands of spectators were trivialities not worth thinking about."
+                ),
+            },
+        )
+
     @require_torch
     @slow
     def test_torch_speech_encoder_decoder(self):
@@ -724,22 +1015,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
     def test_chunk_iterator(self):
         feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
         inputs = torch.arange(100).long()
+        ratio = 1
-        outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0, ratio))
         self.assertEqual(len(outs), 1)
         self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
         self.assertEqual([o["is_last"] for o in outs], [True])
 
         # two chunks no stride
-        outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
+        outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0, ratio))
         self.assertEqual(len(outs), 2)
         self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
         self.assertEqual([o["is_last"] for o in outs], [False, True])
 
         # two chunks incomplete last
-        outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0, ratio))
         self.assertEqual(len(outs), 2)
         self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
@@ -750,7 +1041,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         # This test is specifically crafted to trigger a bug if next chunk
         # would be ignored by the fact that all the data would be
         # contained in the strided left data.
-        outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
+        outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5, ratio))
         self.assertEqual(len(outs), 1)
         self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
@@ -763,20 +1054,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
             "input_values"
         ]
+        ratio = 1
-        outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
+        outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10, ratio))
         self.assertEqual(len(outs), 2)
         self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
         self.assertEqual([o["is_last"] for o in outs], [False, True])
 
-        outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
+        outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10, ratio))
         self.assertEqual(len(outs), 2)
         self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
         self.assertEqual([o["is_last"] for o in outs], [False, True])
 
-        outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
+        outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0, ratio))
         self.assertEqual(len(outs), 2)
         self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
@@ -785,7 +1076,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
         input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
             "input_values"
         ]
-        outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
+        outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5, ratio))
         self.assertEqual(len(outs), 5)
         self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
         self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])