diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index 00726d82cce..68c52c6eb3c 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -252,6 +252,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
                 Specifies the device for computation of the log-mel spectrogram of audio signals in the
                 `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
             return_token_timestamps (`bool`, *optional*, defaults to `None`):
+                Deprecated. Use `return_attention_mask` instead, from which the number of frames can be inferred.
+
                 Whether or not to return the number of frames of the input raw_speech.
                 These num_frames can be used by the model to compute word level timestamps.
         """
@@ -327,6 +329,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
 
         if return_token_timestamps is not None:
+            logger.warning_once(
+                f"`return_token_timestamps` is deprecated for {self.__class__.__name__} and will be removed in Transformers v5. Use `return_attention_mask` instead, as the number of frames can be inferred from it."
+            )
             padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
 
         if return_tensors is not None:
diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 2a64e599d06..248d17cac40 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -331,6 +331,11 @@ class WhisperGenerationMixin(GenerationMixin):
             num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
             num_frames = np.repeat(num_frames, repeat_time)
 
+        # let's ignore decoder_input_ids, which can negatively impact the DTW, since we know they have timestamps of 0.0s
+        # (they are not taken into account for the DTW in the OAI implementation)
+        if num_input_ids is not None:
+            weights = weights[:, :, num_input_ids:, :]
+
         if num_frames is None or isinstance(num_frames, int):
             # Normalize and smoothen the weights.
             std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
@@ -360,7 +365,13 @@ class WhisperGenerationMixin(GenerationMixin):
             text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
             jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
             jump_times = time_indices[jumps] * time_precision
-            timestamps[batch_idx, 1:] = torch.tensor(jump_times)
+
+            # each predicted token has a corresponding timestamp, except the eos token for which we don't retrieve cross attentions
+            # 1. for decoder_input_ids, we set the timestamps to 0.0
+            # 2.
for the eos token, we simply duplicate the timestamp of the last non-eos token + timestamps[batch_idx] = torch.cat( + [torch.zeros(num_input_ids), torch.tensor(jump_times), torch.tensor([jump_times[-1]])] + ) return timestamps @@ -632,7 +643,10 @@ class WhisperGenerationMixin(GenerationMixin): language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config ) self._set_num_frames( - return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs + return_token_timestamps=return_token_timestamps, + generation_config=generation_config, + attention_mask=attention_mask, + kwargs=kwargs, ) self._set_thresholds_and_condition( generation_config=generation_config, @@ -810,10 +824,8 @@ class WhisperGenerationMixin(GenerationMixin): segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, - batch_idx_map=batch_idx_map, seek=seek, - num_segment_frames=num_segment_frames, - max_frames=max_frames, + batch_idx_map=batch_idx_map, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, @@ -928,10 +940,8 @@ class WhisperGenerationMixin(GenerationMixin): segment_input, decoder_input_ids, cur_bsz, - batch_idx_map, seek, - num_segment_frames, - max_frames, + batch_idx_map, temperatures, generation_config, logits_processor, @@ -1003,6 +1013,8 @@ class WhisperGenerationMixin(GenerationMixin): return_token_timestamps=return_token_timestamps, generation_config=generation_config, is_shortform=is_shortform, + seek=seek, + batch_idx_map=batch_idx_map, ) if cur_bsz < batch_size: @@ -1089,6 +1101,8 @@ class WhisperGenerationMixin(GenerationMixin): return_token_timestamps, generation_config, is_shortform, + seek, + batch_idx_map, ): # remove all previously passed decoder input ids # should happen only if it is the first generated segment @@ -1098,7 +1112,11 @@ class WhisperGenerationMixin(GenerationMixin): return seek_outputs[:, start_idx:], seek_outputs if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) + num_frames = getattr(generation_config, "num_frames") + if num_frames is not None: + num_frames = num_frames - seek + num_frames = num_frames[batch_idx_map] + seek_outputs["token_timestamps"] = self._extract_token_timestamps( seek_outputs, generation_config.alignment_heads, @@ -1634,7 +1652,7 @@ class WhisperGenerationMixin(GenerationMixin): ) @staticmethod - def _set_num_frames(return_token_timestamps, generation_config, kwargs): + def _set_num_frames(return_token_timestamps, generation_config, attention_mask, kwargs): if return_token_timestamps: if getattr(generation_config, "task", None) == "translate": logger.warning("Token-level timestamps may not be reliable for task 'translate'.") @@ -1643,7 +1661,24 @@ class WhisperGenerationMixin(GenerationMixin): "Model generation config has no `alignment_heads`, token-level timestamps not available. " "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." 
            )
-            generation_config.num_frames = kwargs.pop("num_frames", None)
+            if "num_frames" in kwargs:
+                generation_config.num_frames = kwargs.pop("num_frames")
+                if isinstance(generation_config.num_frames, torch.Tensor):
+                    generation_config.num_frames = generation_config.num_frames.cpu()
+                else:
+                    generation_config.num_frames = torch.tensor(generation_config.num_frames)
+
+                logger.warning_once(
+                    "`num_frames` is deprecated and will be removed in Transformers v5. Use `attention_mask` instead, as it can be used to infer the number of frames. "
+                    "You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)`."
+                )
+            elif attention_mask is not None:
+                generation_config.num_frames = attention_mask.sum(-1).cpu()
+            else:
+                logger.warning_once(
+                    "When setting `return_token_timestamps` to `True`, make sure to pass an `attention_mask` to get precise token-level timestamps. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)`."
+                )
+                generation_config.num_frames = None
 
     @staticmethod
     def _set_thresholds_and_condition(
diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 44f8a745fd0..2d9dd6845c4 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1099,11 +1099,11 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
                     # merges later and decode into text.
                     current_tokens.append(token)
                     if return_timestamps == "word":
-                        start_time = round(token_timestamps[i] + time_offset, 2)
-                        if i + 1 < len(token_timestamps):
-                            end_time = round(token_timestamps[i + 1] + time_offset, 2)
+                        if i == 0:
+                            start_time = round(0.0 + time_offset, 2)
                         else:
-                            end_time = None  # should never happen
+                            start_time = round(token_timestamps[i - 1] + time_offset, 2)
+                        end_time = round(token_timestamps[i] + time_offset, 2)
                         current_token_timestamps.append((start_time, end_time))
 
         if "stride" in output:
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index e8b4af94c72..232ef4463b4 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -495,19 +495,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # custom processing for Whisper timestamps and word-level timestamps
         return_timestamps = return_timestamps or getattr(self.generation_config, "return_timestamps", False)
         if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
+            generate_kwargs["return_timestamps"] = bool(return_timestamps)
             if return_timestamps == "word":
                 generate_kwargs["return_token_timestamps"] = True
                 generate_kwargs["return_segments"] = True
 
-                if stride is not None:
-                    if isinstance(stride, tuple):
-                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
-                    else:
-                        generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
-                else:
-                    generate_kwargs["num_frames"] = num_frames
-
         # User-defined `generation_config` passed to the pipeline call take precedence
         if "generation_config" not in generate_kwargs:
             generate_kwargs["generation_config"] = self.generation_config
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 860ec88b847..7888d7bab8b 100644
---
a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1793,7 +1793,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_OUTPUT = torch.tensor([ - 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50364, 393, 4411, 13, 50514 + [50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50364, 393, 4411, 13, 50514] ]) # fmt: on @@ -2109,10 +2109,10 @@ class WhisperModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_OUTPUT = torch.tensor([ - [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200], - [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000], - [0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800], - [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200] + [0.0000, 0.8200, 0.9800, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 1.9800, 2.3400, 2.5000, 2.6600, 3.2000, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000, 11.9000], + [0.0000, 0.9000, 1.1400, 1.4200, 1.5200, 1.6600, 1.6600, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9400, 4.4000, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600, 17.9600], 
+ [0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9400, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0800, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 16.6000, 16.6000], + [0.0000, 0.7400, 1.0400, 1.3000, 1.6800, 2.1200, 2.4800, 2.7600, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4000, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4000, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200] ]) # fmt: on @@ -2139,10 +2139,10 @@ class WhisperModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_OUTPUT = torch.tensor([ - [0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200], - [0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], - [0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600], - [0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800] + [0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200], + [0.0000, 0.7600, 0.9800, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.2000, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], + [0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600, 12.4600], + [0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1600, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800] ]) # fmt: on @@ -2173,7 +2173,6 @@ class WhisperModelIntegrationTests(unittest.TestCase): ) # task id 
and lang id prompts should not have timestamp tokens - self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1]) self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples) @slow @@ -2210,18 +2209,18 @@ class WhisperModelIntegrationTests(unittest.TestCase): # fmt: off EXPECTED_OUTPUT = [ - torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]), - torch.tensor([6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]), - torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]), - torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]), - torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200, 29.9800]), - torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]), - torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]), - torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]), - torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]), - torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]), - torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200, 59.4200]), - torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]), + torch.tensor([0.0000, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5000, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600, 6.5400]), + torch.tensor([6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000, 11.2200]), + torch.tensor([11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1600, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800, 16.9800]), + torch.tensor([16.9800, 17.3200, 18.1800, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8400, 23.7000, 23.7000]), + torch.tensor([23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.3800, 26.5800, 26.7600, 27.1600, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200, 29.9800, 29.9800]), + torch.tensor([29.4400, 29.7000, 30.0600, 30.3800, 30.5400, 30.8200, 
31.0600, 31.6600, 31.9200, 32.3000, 32.5000, 32.6200, 33.6800, 33.8000]), + torch.tensor([33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600, 40.5200]), + torch.tensor([40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000, 44.7000]), + torch.tensor([44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400, 50.5400]), + torch.tensor([50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400, 52.9600]), + torch.tensor([52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1600, 58.5200, 58.6400, 58.8200, 59.4200, 59.4200]), + torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.4200, 62.4200]) ] # fmt: on diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 40fed6d76fb..f31d7da0554 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -344,13 +344,8 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): model_outputs = [ { 'stride': [10, 0, 5], - 'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]), - 'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]]) - }, - { - 'stride': [10, 5, 0], - 'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]), - 'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]]) + 'tokens': np.array([[50363, 3363, 11, 345, 460, 0, 50423]]), + 'token_timestamps': np.array([[0.0, 0.5, 0.52, 0.78, 1.2, 1.28, 1.28]]) } ] # fmt: on @@ -361,15 +356,12 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ) EXPECTED_OUTPUT = ( - " Yes, you can! Just do it", + " Yes, you can!", { "chunks": [ - {"text": " Yes,", "timestamp": (5.18, 5.56)}, - {"text": " you", "timestamp": (5.56, 5.84)}, - {"text": " can!", "timestamp": (5.84, 7.12)}, - {"text": " Just", "timestamp": (7.12, 7.56)}, - {"text": " do", "timestamp": (7.56, 7.8)}, - {"text": " it", "timestamp": (7.8, 8.72)}, + {"text": " Yes,", "timestamp": (0.0, 0.52)}, + {"text": " you", "timestamp": (0.52, 0.78)}, + {"text": " can!", "timestamp": (0.78, 1.28)}, ] }, )
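Illustrative usage, not part of the patch: a minimal sketch of the migration this diff's deprecation warnings point to, replacing the deprecated `num_frames` / feature-extractor `return_token_timestamps` arguments with `return_attention_mask=True` at preprocessing time and `return_token_timestamps=True` at generation time. The `openai/whisper-tiny` checkpoint and the silent dummy waveform below are placeholder assumptions.

```python
# Minimal sketch, assuming the openai/whisper-tiny checkpoint and a dummy 16 kHz waveform:
# ask the processor for the attention mask so generate() can infer the number of frames
# per sample, instead of passing the deprecated `num_frames` argument.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16000 * 5, dtype=np.float32)  # placeholder: 5 seconds of silence
inputs = processor(audio, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")

outputs = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,  # lets generate() derive num_frames for the DTW alignment
    return_token_timestamps=True,
)
print(outputs["token_timestamps"])  # one timestamp per generated token
```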