Mirror of https://github.com/huggingface/transformers.git
[Whisper] 🚨 Fix pipeline word timestamp: timestamp token is end of token time !!! (#36632)
* timestamp token is end of token time !!!
* ensure correct alignment between tokens and timestamp tokens
* ignore input tokens for DTW computation
* use num_frames to avoid token timestamp hallucinations
* token timestamps test updates !
* num_frames: deprecate and use attention_mask instead
* avoid breaking change
* fix the pipeline usage for chunk approach
* make style
* better logging
* better logging
* make style
* update tests with correct values
parent 9c8d3a70b8
commit 2b85b6ce19
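Context for the fix: Whisper's aligned token timestamps mark where a token *ends*, not where it starts, and the pipeline previously used each token's timestamp as its start time, skewing every word boundary in chunked long-form decoding. A minimal usage sketch of the behavior this commit fixes; the checkpoint name, chunk length, and audio file below are placeholder assumptions, not taken from the diff:

```python
# Hypothetical usage sketch: word-level timestamps via the ASR pipeline.
# "openai/whisper-tiny" and "sample.wav" are placeholders, not part of this commit.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,  # chunked long-form decoding, the case this commit fixes
)

out = asr("sample.wav", return_timestamps="word")
for chunk in out["chunks"]:
    # with this fix, each (start, end) pair ends at the token's own
    # timestamp and starts at the previous token's timestamp
    print(chunk["text"], chunk["timestamp"])
```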
@@ -252,6 +252,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
                 Specifies the device for computation of the log-mel spectrogram of audio signals in the
                 `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
             return_token_timestamps (`bool`, *optional*, defaults to `None`):
+                Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
+
                 Whether or not to return the number of frames of the input raw_speech.
                 These num_frames can be used by the model to compute word level timestamps.
         """
@@ -327,6 +329,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]

         if return_token_timestamps is not None:
+            logger.warning_once(
+                f"`return_token_timestamps` is deprecated for {self.__class__.__name__} and will be removed in Transformers v5. Use `return_attention_mask` instead, as the number of frames can be inferred from it."
+            )
             padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]

         if return_tensors is not None:
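Because the attention mask is subsampled by `hop_length` just above, each remaining mask position corresponds to one feature frame, so the per-sample frame count can be recovered as a simple sum. A small sketch of that equivalence; the hop length, clip lengths, and names are illustrative, not from the diff:

```python
# Illustrative sketch: the hop-subsampled attention mask encodes num_frames.
# `raw_speech` and `hop_length` are stand-in values, not taken from this commit.
import torch

hop_length = 160
raw_speech = [torch.randn(16000), torch.randn(8000)]  # two audio clips

# padded mask over raw samples, then strided by hop_length as in the diff
max_len = max(len(x) for x in raw_speech)
mask = torch.stack(
    [torch.arange(max_len) < len(x) for x in raw_speech]
).long()[:, ::hop_length]

num_frames = mask.sum(-1)  # == [len(x) // hop_length for x in raw_speech] (up to rounding)
print(num_frames)  # tensor([100, 50])
```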
@@ -331,6 +331,11 @@ class WhisperGenerationMixin(GenerationMixin):
             num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
             num_frames = np.repeat(num_frames, repeat_time)

+        # let's ignore decoder_input_ids that can negatively impact the DTW while we know they have timestamps 0.0s
+        # (they are not taken into account for the DTW in OAI implementation)
+        if num_input_ids is not None:
+            weights = weights[:, :, num_input_ids:, :]
+
         if num_frames is None or isinstance(num_frames, int):
             # Normalize and smoothen the weights.
             std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
@@ -360,7 +365,13 @@ class WhisperGenerationMixin(GenerationMixin):
             text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
             jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
             jump_times = time_indices[jumps] * time_precision
-            timestamps[batch_idx, 1:] = torch.tensor(jump_times)
+
+            # each predicted token has a corresponding timestamp, except the eos token for which we don't retrieve cross attentions
+            # 1. for decoder_input_ids, we set the timestamps to 0.0
+            # 2. for the eos token, we simply duplicate the timestamp of the last non-eos token
+            timestamps[batch_idx] = torch.cat(
+                [torch.zeros(num_input_ids), torch.tensor(jump_times), torch.tensor([jump_times[-1]])]
+            )

         return timestamps

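The shape bookkeeping in that assembly step is easy to get wrong, so a toy rendering may help: with `num_input_ids` prompt tokens and the DTW yielding one jump time per generated token, the final vector is zeros for the prompt, then the jump times, then the last jump time repeated once for eos. The values below are invented for illustration:

```python
# Toy illustration of the timestamp assembly above; values are invented.
import torch

num_input_ids = 3                             # e.g. <|sot|><|en|><|transcribe|>
jump_times = torch.tensor([0.0, 0.48, 0.82])  # one DTW jump time per generated token

timestamps = torch.cat(
    [torch.zeros(num_input_ids), jump_times, jump_times[-1:]]
)
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.8200])
# prompt tokens -> 0.0; eos -> duplicate of the last non-eos timestamp
```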
@@ -632,7 +643,10 @@ class WhisperGenerationMixin(GenerationMixin):
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
         self._set_num_frames(
-            return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
+            return_token_timestamps=return_token_timestamps,
+            generation_config=generation_config,
+            attention_mask=attention_mask,
+            kwargs=kwargs,
         )
         self._set_thresholds_and_condition(
             generation_config=generation_config,
@@ -810,10 +824,8 @@ class WhisperGenerationMixin(GenerationMixin):
                 segment_input=segment_input,
                 decoder_input_ids=decoder_input_ids,
                 cur_bsz=cur_bsz,
-                batch_idx_map=batch_idx_map,
                 seek=seek,
-                num_segment_frames=num_segment_frames,
-                max_frames=max_frames,
+                batch_idx_map=batch_idx_map,
                 temperatures=temperatures,
                 generation_config=generation_config,
                 logits_processor=logits_processor,
@@ -928,10 +940,8 @@ class WhisperGenerationMixin(GenerationMixin):
         segment_input,
         decoder_input_ids,
         cur_bsz,
-        batch_idx_map,
         seek,
-        num_segment_frames,
-        max_frames,
+        batch_idx_map,
         temperatures,
         generation_config,
         logits_processor,
@@ -1003,6 +1013,8 @@ class WhisperGenerationMixin(GenerationMixin):
                 return_token_timestamps=return_token_timestamps,
                 generation_config=generation_config,
                 is_shortform=is_shortform,
+                seek=seek,
+                batch_idx_map=batch_idx_map,
             )

             if cur_bsz < batch_size:
@@ -1089,6 +1101,8 @@ class WhisperGenerationMixin(GenerationMixin):
         return_token_timestamps,
         generation_config,
         is_shortform,
+        seek,
+        batch_idx_map,
     ):
         # remove all previously passed decoder input ids
         # should happen only if it is the first generated segment
@@ -1098,7 +1112,11 @@ class WhisperGenerationMixin(GenerationMixin):
             return seek_outputs[:, start_idx:], seek_outputs

         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
-            num_frames = getattr(generation_config, "num_frames", None)
+            num_frames = getattr(generation_config, "num_frames")
+            if num_frames is not None:
+                num_frames = num_frames - seek
+                num_frames = num_frames[batch_idx_map]
+
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                 seek_outputs,
                 generation_config.alignment_heads,
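For long-form decoding, `num_frames` holds the total per-sample frame count while the cross attentions only cover the current window, so the frames remaining after `seek` must be used, restricted to the samples still being decoded. A toy illustration; the tensor values below are invented, not from the diff:

```python
# Invented values illustrating the per-segment adjustment above.
import torch

num_frames = torch.tensor([3000, 1500, 2200])  # total frames per sample
seek = torch.tensor([1500, 1500, 1500])        # frames already decoded this pass
batch_idx_map = [0, 2]                         # sample 1 already finished

remaining = (num_frames - seek)[batch_idx_map]
print(remaining)  # tensor([1500,  700]) -> frames left for the active samples
```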
@@ -1634,7 +1652,7 @@ class WhisperGenerationMixin(GenerationMixin):
         )

     @staticmethod
-    def _set_num_frames(return_token_timestamps, generation_config, kwargs):
+    def _set_num_frames(return_token_timestamps, generation_config, attention_mask, kwargs):
         if return_token_timestamps:
             if getattr(generation_config, "task", None) == "translate":
                 logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
@@ -1643,7 +1661,24 @@ class WhisperGenerationMixin(GenerationMixin):
                     "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
-            generation_config.num_frames = kwargs.pop("num_frames", None)
+            if "num_frames" in kwargs:
+                generation_config.num_frames = kwargs.pop("num_frames")
+                if isinstance(generation_config.num_frames, torch.Tensor):
+                    generation_config.num_frames = generation_config.num_frames.cpu()
+                else:
+                    generation_config.num_frames = torch.tensor(generation_config.num_frames)
+
+                logger.warning_once(
+                    "`num_frames` is deprecated and will be removed in Transformers v5. Use `attention_mask` instead, as it can be used to infer the number of frames. "
+                    "You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)`"
+                )
+            elif attention_mask is not None:
+                generation_config.num_frames = attention_mask.sum(-1).cpu()
+            else:
+                logger.warning_once(
+                    "When setting `return_token_timestamps` to `True`, make sure to pass an `attention_mask` to get precise token-level timestamps. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)`"
+                )
+                generation_config.num_frames = None

     @staticmethod
     def _set_thresholds_and_condition(
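Putting the deprecation together: the recommended path is now to request the attention mask from the processor and pass it to `generate`, which infers `num_frames` from it internally. A hedged end-to-end sketch; the checkpoint name and the synthetic audio are placeholders, not part of the commit:

```python
# Sketch of the recommended call pattern after this change; the checkpoint
# and the synthetic audio are placeholders, not part of the commit.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.random.randn(16000 * 5)  # 5 s of fake audio at 16 kHz
inputs = processor(
    audio, sampling_rate=16000, return_tensors="pt", return_attention_mask=True
)

out = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,  # replaces the deprecated num_frames kwarg
    return_token_timestamps=True,
)
print(out.token_timestamps)  # end-of-token times, one per generated token
```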
@@ -1099,11 +1099,11 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
                     # merges later and decode into text.
                     current_tokens.append(token)
                     if return_timestamps == "word":
-                        start_time = round(token_timestamps[i] + time_offset, 2)
-                        if i + 1 < len(token_timestamps):
-                            end_time = round(token_timestamps[i + 1] + time_offset, 2)
+                        if i == 0:
+                            start_time = round(0.0 + time_offset, 2)
                         else:
-                            end_time = None  # should never happen
+                            start_time = round(token_timestamps[i - 1] + time_offset, 2)
+                        end_time = round(token_timestamps[i] + time_offset, 2)
                         current_token_timestamps.append((start_time, end_time))

                 if "stride" in output:
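This hunk is the heart of the fix: since each aligned timestamp marks where a token *ends*, a token's span runs from the previous token's timestamp to its own. A toy rendering of the new pairing; the times below are invented (they happen to match the tokenizer test values further down):

```python
# Invented token times illustrating end-of-token pairing in _decode_asr.
token_timestamps = [0.52, 0.78, 1.28]  # end time of each decoded token
time_offset = 0.0

spans = []
for i, end in enumerate(token_timestamps):
    start = 0.0 if i == 0 else token_timestamps[i - 1]
    spans.append((round(start + time_offset, 2), round(end + time_offset, 2)))

print(spans)  # [(0.0, 0.52), (0.52, 0.78), (0.78, 1.28)]
```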
@@ -495,19 +495,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # custom processing for Whisper timestamps and word-level timestamps
         return_timestamps = return_timestamps or getattr(self.generation_config, "return_timestamps", False)
         if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
+            generate_kwargs["return_timestamps"] = bool(return_timestamps)
             if return_timestamps == "word":
                 generate_kwargs["return_token_timestamps"] = True
                 generate_kwargs["return_segments"] = True
-
-                if stride is not None:
-                    if isinstance(stride, tuple):
-                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
-                    else:
-                        generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
-                else:
-                    generate_kwargs["num_frames"] = num_frames

         # User-defined `generation_config` passed to the pipeline call take precedence
         if "generation_config" not in generate_kwargs:
             generate_kwargs["generation_config"] = self.generation_config
@@ -1793,7 +1793,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
-            50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50364, 393, 4411, 13, 50514
+            [50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50364, 393, 4411, 13, 50514]
         ])
         # fmt: on

@@ -2109,10 +2109,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
-            [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200],
-            [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000],
-            [0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800],
-            [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
+            [0.0000, 0.8200, 0.9800, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 1.9800, 2.3400, 2.5000, 2.6600, 3.2000, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000],
+            [0.0000, 0.9000, 1.1400, 1.4200, 1.5200, 1.6600, 1.6600, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9400, 4.4000, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600],
+            [0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9400, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0800, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 16.6000, 16.6000],
+            [0.0000, 0.7400, 1.0400, 1.3000, 1.6800, 2.1200, 2.4800, 2.7600, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4000, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4000, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
         ])
         # fmt: on

@@ -2139,10 +2139,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
-            [0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
-            [0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
-            [0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600],
-            [0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
+            [0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
+            [0.0000, 0.7600, 0.9800, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.2000, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
+            [0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600, 12.4600],
+            [0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1600, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
         ])
         # fmt: on

@@ -2173,7 +2173,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         )

         # task id and lang id prompts should not have timestamp tokens
-        self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
         self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)

     @slow
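A plausible reading of this removed assertion (my assumption, not stated in the diff): since prompt positions are now included in the timestamps (as 0.0) and the eos timestamp is duplicated, `token_timestamps` lines up one-to-one with `sequences`, so the old `- 2` adjustment for the task and language prompt tokens no longer applies:

```python
# Assumption for illustration, not an excerpt from the test: token_timestamps
# now covers every position of `sequences`, prompt tokens included (as 0.0).
assert generate_outputs["sequences"].shape[-1] == generate_outputs["token_timestamps"].shape[-1]
```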
@@ -2210,18 +2209,18 @@ class WhisperModelIntegrationTests(unittest.TestCase):

         # fmt: off
         EXPECTED_OUTPUT = [
-            torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
-            torch.tensor([6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
-            torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
-            torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
-            torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200, 29.9800]),
-            torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
-            torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
-            torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
-            torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
-            torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
-            torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200, 59.4200]),
-            torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
+            torch.tensor([0.0000, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5000, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600, 6.5400]),
+            torch.tensor([6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000, 11.2200]),
+            torch.tensor([11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1600, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800, 16.9800]),
+            torch.tensor([16.9800, 17.3200, 18.1800, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8400, 23.7000, 23.7000]),
+            torch.tensor([23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.3800, 26.5800, 26.7600, 27.1600, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200, 29.9800, 29.9800]),
+            torch.tensor([29.4400, 29.7000, 30.0600, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.5000, 32.6200, 33.6800, 33.8000]),
+            torch.tensor([33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600, 40.5200]),
+            torch.tensor([40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000, 44.7000]),
+            torch.tensor([44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400, 50.5400]),
+            torch.tensor([50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400, 52.9600]),
+            torch.tensor([52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1600, 58.5200, 58.6400, 58.8200, 59.4200, 59.4200]),
+            torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.4200, 62.4200])
         ]
         # fmt: on

@@ -344,13 +344,8 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         model_outputs = [
             {
                 'stride': [10, 0, 5],
-                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
-                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
-            },
-            {
-                'stride': [10, 5, 0],
-                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
-                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
+                'tokens': np.array([[50363, 3363, 11, 345, 460, 0, 50423]]),
+                'token_timestamps': np.array([[0.0, 0.5, 0.52, 0.78, 1.2, 1.28, 1.28]])
             }
         ]
         # fmt: on
@@ -361,15 +356,12 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         )

         EXPECTED_OUTPUT = (
-            " Yes, you can! Just do it",
+            " Yes, you can!",
             {
                 "chunks": [
-                    {"text": " Yes,", "timestamp": (5.18, 5.56)},
-                    {"text": " you", "timestamp": (5.56, 5.84)},
-                    {"text": " can!", "timestamp": (5.84, 7.12)},
-                    {"text": " Just", "timestamp": (7.12, 7.56)},
-                    {"text": " do", "timestamp": (7.56, 7.8)},
-                    {"text": " it", "timestamp": (7.8, 8.72)},
+                    {"text": " Yes,", "timestamp": (0.0, 0.52)},
+                    {"text": " you", "timestamp": (0.52, 0.78)},
+                    {"text": " can!", "timestamp": (0.78, 1.28)},
                 ]
             },
         )