mirror of
https://github.com/huggingface/transformers.git
[Whisper] 🚨 Fix whisper decoding 🚨 (#34135)
* do not remove decoder_input_ids for the first segment
* do not remove eos token in generate_with_fallback
* when removing padding tokens, do not remove eos token
* remove eos token in generate (and not in generate_with_fallback!)
* reconciliate short-from/ long-form behavior
* correct avg_logprobs calculation
* handle eos token in segments
* handle decoder_input_ids and eos token in _prepare_decoder_input_ids
* fix incorrect time precision
* always remove eos token
* always remove decoder_input_ids
* no need to handle decoder_inputs_ids and eos token
* no need to remove decoder_input_ids
* no need to handle eos token
* fix num_beams in _retrieve_logit_processors
* remove todo unconsistency
* no need to add eos token
* last_timestamp_pos should indeed be timestamp token pos
* patch generate to enable compatibility with GenerationTesterMixin tests
* adapt test_generate_continue_from_past_key_values
* adapt test_prompt_lookup_decoding_matches_greedy_search
* adapt generic GenerationMixin tests to whisper's generate
* fix speculative decoding
* fix
* [run-slow] whisper
* change HF_HUB_TOKEN for require_read_token
* [run-slow] whisper
* prioritize kwargs over generation_config
* remove unnecessary args
* [run-slow] whisper
* update tests
* [run-slow] whisper
* add comment
* update test
* [run-slow] whisper
* update test + revert require_read_token
* docstring updates
* revert tokenizer decode args change
* do not use a patch + docstring updates
* [run-slow] whisper
* make
* [run-slow] whisper
* add a flag to force unique call to generate
* test update
* [run-slow] whisper
* add force_unique_generate_call arg
* do not use a patch
* correct the timestamps for the pad tokens
* docstring update
* docstring update
* docstring update
* upodate TF tests
* add require_read_token
* [run-slow] whisper
* test reset dynamo
* [run-slow] whisper
* fix
* [run-slow] whisper
* avoid iterating twice on current_segments
* [run-slow] whisper
* [run-slow] whisper

---------

Co-authored-by: Eustache Le Bihan <eustlb@users.noreply.huggingface.co>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent: f1b7634fc8
commit: da334bcfa8
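For context, a minimal usage sketch of the behavior this commit changes (the checkpoint name and the zero-filled features below are illustrative assumptions, not part of the diff): short-form calls now keep the decoder prompt ids and the end-of-sequence token in the returned ids, and the new `force_unique_generate_call` flag guarantees a single underlying call to GenerationMixin's `generate`.

```python
# Sketch only: "openai/whisper-tiny" and the dummy features are illustrative assumptions.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# <= 3000 mel frames (<= 30 s) is treated as short-form.
input_features = torch.zeros(1, model.config.num_mel_bins, 3000)

# Short-form: a single call to GenerationMixin's generate. After this commit the returned
# ids include the decoder input ids and the eos token (the flagged breaking change).
ids = model.generate(input_features)

# New flag added by this commit: force exactly one underlying generate call.
ids_single_call = model.generate(input_features, force_unique_generate_call=True)

text = processor.batch_decode(ids, skip_special_tokens=True)
```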
@@ -133,9 +133,12 @@ def _pad_to_max_length(
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
return_token_timestamps=False,
force_unique_generate_call=False,
):
max_total_length = 0
sequences = []
token_timestamps_list = []

if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
@@ -145,31 +148,74 @@ def _pad_to_max_length(
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")

if force_unique_generate_call:
sequences_list = []
timestamps_list = []
for segments in current_segments:
result = segments[0]["result"]
sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
if return_token_timestamps:
timestamps_list.append(result["token_timestamps"])

sequences = torch.stack(sequences_list, dim=0)
if return_token_timestamps:
token_timestamps = torch.stack(timestamps_list, dim=0)
return sequences, token_timestamps
return sequences

for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
if return_token_timestamps:
token_timestamps = torch.cat(
[d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
dim=-1,
)

if cut_off_length is not None:
sequence = sequence[-cut_off_length:]
if return_token_timestamps:
token_timestamps = token_timestamps[-cut_off_length:]

if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])

if return_token_timestamps:
token_timestamps = torch.cat(
[torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
)
sequences.append(sequence)
if return_token_timestamps:
token_timestamps_list.append(token_timestamps)
max_total_length = max(max_total_length, len(sequences[-1]))
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
if return_token_timestamps:
token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
else:
sequences.append(torch.tensor([], device=device))
if return_token_timestamps:
token_timestamps_list.append(torch.tensor([], device=device))

max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)

sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
if return_token_timestamps:
token_timestamps_list[i] = F.pad(
token_timestamps_list[i],
pad=pad,
value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
)

sequences = torch.stack(sequences, dim=0)
return sequences

if return_token_timestamps:
token_timestamps = torch.stack(token_timestamps_list, dim=0)
return sequences, token_timestamps
else:
return sequences


class WhisperGenerationMixin(GenerationMixin):
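A small self-contained sketch of the padding scheme the updated `_pad_to_max_length` applies per batch element (toy tensors, not the helper itself): token sequences are right-padded with `pad_token_id` up to the longest length, while token timestamps are padded with their last value so padded positions keep a valid time.

```python
# Toy illustration of the padding behavior; values and lengths are made up.
import torch
import torch.nn.functional as F

pad_token_id = 50257
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
timestamps = [torch.tensor([0.00, 0.48, 0.96]), torch.tensor([0.00, 0.52])]

max_len = max(len(seq) for seq in sequences)
padded_seqs, padded_ts = [], []
for seq, ts in zip(sequences, timestamps):
    pad = (0, max_len - len(seq))  # right padding, as used by generate()
    padded_seqs.append(F.pad(seq, pad=pad, value=pad_token_id))
    # timestamps are padded with their last value instead of pad_token_id
    padded_ts.append(F.pad(ts, pad=pad, value=float(ts[-1]) if len(ts) > 0 else 0.0))

print(torch.stack(padded_seqs))  # [[1, 2, 3], [4, 5, 50257]]
print(torch.stack(padded_ts))    # [[0.00, 0.48, 0.96], [0.00, 0.52, 0.52]]
```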
@@ -312,6 +358,7 @@ class WhisperGenerationMixin(GenerationMixin):
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
force_unique_generate_call: Optional[bool] = None,
**kwargs,
):
"""
@@ -432,27 +479,39 @@ class WhisperGenerationMixin(GenerationMixin):
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
force_unique_generate_call (`bool`, *optional*):
Whether to force a unique call to the underlying GenerationMixin's generate method. This is useful for assisted decoding and testing purposes to ensure
that only one call to generate is made and therefore decoder input token ids and eos token ids are returned.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

Return:
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
[`~utils.ModelOutput`] or `Dict[str, Any]` or `torch.LongTensor`:

If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
A:
- [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
- `Dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
- `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.

else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
The possible [`~utils.ModelOutput`] types are:
- [`~utils.GenerateEncoderDecoderOutput`]
- [`~utils.GenerateBeamEncoderDecoderOutput`]

- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
`segments` is a list of lists (one list per batch element) of `segment`.
A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`.
- `start`: the start timestamp of the segment.
- `end`: the end timestamp of the segment.
- `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id.
- `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's `generate` (present in `result`).
- `result`: the result of the underlying call to GenerationMixin's `generate`.

else only the generated output sequence ids are returned.
When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's `generate`, with outputs stored in `result` of each `segment`.
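As an illustration of the segment structure documented above (the `model`, `processor`, and long-form `input_features` below are assumed to exist elsewhere; names are placeholders), each returned segment carries its timestamps, its tokens, and the raw result of the underlying generate call:

```python
# Sketch: inspecting long-form segments; `model`, `processor` and `input_features`
# (> 3000 mel frames) are assumed to be defined elsewhere.
outputs = model.generate(
    input_features,
    return_timestamps=True,
    return_segments=True,
)

for segment in outputs["segments"][0]:  # segments of the first batch element
    text = processor.decode(segment["tokens"], skip_special_tokens=True)
    start_idx, end_idx = segment["idxs"]  # slice into the underlying generate result
    print(f"{float(segment['start']):.2f}s -> {float(segment['end']):.2f}s: {text}")
```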

Example:

- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. It is necessary to set `return_timestamps=True`.
Indeed, long-form transcription uses a sequential algorithm based on timestamps predictions, with heuristics like compression ratio threshold, log probability threshold and temperature fallback. This algorithm is described in the [the Whisper original paper](https://cdn.openai.com/papers/whisper.pdf), section *3.8. Long-form Transcription*.

```python
>>> import torch
@@ -483,7 +542,9 @@ class WhisperGenerationMixin(GenerationMixin):
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```

- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
- *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities:
- `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's generate.
- `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.

```python
>>> import torch
@@ -570,11 +631,21 @@ class WhisperGenerationMixin(GenerationMixin):
# 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1]
num_beams = kwargs.get(
"num_beams",
generation_config.num_beams
if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
else 1,
)
if "assistant_model" in kwargs:
# speculative decoding: the model should be able to return eos token
generation_config.begin_suppress_tokens = None

logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
num_beams=kwargs.get("num_beams", 1),
num_beams=num_beams,
device=device,
)
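The precedence implemented above (an explicit `num_beams` kwarg wins, otherwise the generation config, otherwise 1) can be reproduced in isolation; the helper name below is illustrative only:

```python
# Illustrative helper mirroring the precedence: kwargs > generation_config > default of 1.
def resolve_num_beams(kwargs: dict, generation_config) -> int:
    config_beams = getattr(generation_config, "num_beams", None)
    return kwargs.get("num_beams", config_beams if config_beams is not None else 1)

class _Cfg:
    num_beams = None

assert resolve_num_beams({"num_beams": 5}, _Cfg()) == 5  # explicit kwarg wins
assert resolve_num_beams({}, _Cfg()) == 1                # config has no usable value -> 1
_Cfg.num_beams = 2
assert resolve_num_beams({}, _Cfg()) == 2                # config value is used
```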

@@ -618,6 +689,19 @@ class WhisperGenerationMixin(GenerationMixin):
batch_size=cur_bsz,
generation_config=generation_config,
)
# 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
# we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
if "assistant_model" in kwargs:
assistant_model = kwargs["assistant_model"]
assistant_model.generation_config.force_unique_generate_call = True

if force_unique_generate_call is None:
if hasattr(generation_config, "force_unique_generate_call"):
force_unique_generate_call = generation_config.force_unique_generate_call
elif hasattr(self.generation_config, "force_unique_generate_call"):
force_unique_generate_call = self.generation_config.force_unique_generate_call
else:
force_unique_generate_call = False

# 6 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
@@ -729,14 +813,15 @@ class WhisperGenerationMixin(GenerationMixin):
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
decoder_input_ids=decoder_input_ids,
)

seek[prev_i] += segment_offset

current_segments[prev_i] += segments

if is_shortform:
seek[prev_i] += max_frames[i]
else:
seek[prev_i] += segment_offset
if force_unique_generate_call:
break

# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -746,51 +831,62 @@ class WhisperGenerationMixin(GenerationMixin):
else current_segments
)

sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": final_segments}

if is_shortform:
# add eos token:
if generation_config.max_new_tokens is None and generation_config.max_length is None:
eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
sequences = torch.cat([sequences, eos_tokens], dim=-1)

if return_token_timestamps:
outputs = {}
outputs["sequences"] = sequences
outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
else:
outputs = sequences

if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)

if num_return_sequences > 1:
if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
dict_outputs.encoder_attentions = tuple(
dict_outputs.encoder_attentions[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_attentions))
)
if (
hasattr(dict_outputs, "encoder_hidden_states")
and dict_outputs.encoder_hidden_states is not None
):
dict_outputs.encoder_hidden_states = tuple(
dict_outputs.encoder_hidden_states[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_hidden_states))
)
if return_token_timestamps:
dict_outputs["token_timestamps"] = outputs["token_timestamps"]
return dict_outputs

# if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
# -> we can return a ModelOutput
# otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
if (
return_dict_in_generate
and generation_config.return_dict_in_generate
and (force_unique_generate_call or not return_timestamps)
):
# only one call to generate_with_fallback, we can return a ModelOutput
outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
if num_return_sequences > 1:
if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
outputs.encoder_attentions = tuple(
outputs.encoder_attentions[i][::num_return_sequences]
for i in range(len(outputs.encoder_attentions))
)
if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
outputs.encoder_hidden_states = tuple(
outputs.encoder_hidden_states[i][::num_return_sequences]
for i in range(len(outputs.encoder_hidden_states))
)
return outputs

return sequences
padded_outputs = _pad_to_max_length(
current_segments=final_segments,
pad_token_id=generation_config.pad_token_id,
device=self.device,
padding_side="right",
return_token_timestamps=return_token_timestamps,
force_unique_generate_call=force_unique_generate_call,
)

if return_dict_in_generate and generation_config.return_dict_in_generate:
logger.warning_once(
"You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
)
return_segments = True
elif not return_segments and not return_token_timestamps:
return padded_outputs

if return_token_timestamps:
sequences, token_timestamps = padded_outputs
outputs = {
"sequences": sequences,
"token_timestamps": token_timestamps,
}
else:
sequences = padded_outputs
outputs = {
"sequences": sequences,
}

if return_segments:
outputs["segments"] = final_segments

return outputs
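When `return_dict_in_generate=True` is combined with `return_timestamps=True`, the warning above means the per-call generation outputs now live in each segment's `result`; a hedged access pattern follows (all names are placeholders assumed to exist):

```python
# Sketch: per-call outputs under each segment's "result" when timestamps are requested.
outputs = model.generate(
    input_features,
    return_timestamps=True,
    return_dict_in_generate=True,
    output_scores=True,
)

sequences = outputs["sequences"]           # padded, batched token ids
first_segment = outputs["segments"][0][0]  # first segment of the first batch element
result = first_segment["result"]           # output of the underlying generate call
scores = result.scores                     # present because output_scores=True
```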

def generate_with_fallback(
self,
@@ -886,22 +982,14 @@ class WhisperGenerationMixin(GenerationMixin):
new_decoder_attention_mask = []

for i, seek_sequence in enumerate(seek_sequences):
# make sure we cut a predicted EOS token if we are not finished with the generation yet
prev_i = batch_idx_map[fallback_index_map[i]]
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]

# remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]

# remove all padding tokens
# remove all padding tokens, except for the eos token
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
if generation_config.pad_token_id == generation_config.eos_token_id:
# we do not remove the eos token id since it is needed for avg logprob calculation in _need_fallback
num_paddings -= 1
if num_paddings != 0:
seek_sequence = seek_sequence[:-num_paddings]
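A standalone sketch of the stripping rule above (toy ids, illustrative only): trailing padding is removed, but when `pad_token_id == eos_token_id` one trailing token is kept so the eos log-probability still enters the average used by `_need_fallback`.

```python
# Toy version of the trailing-padding removal that keeps the eos when pad == eos.
import torch

def strip_trailing_padding(seq, pad_token_id, eos_token_id):
    if len(seq) == 0 or seq[-1] != pad_token_id:
        return seq
    num_paddings = int((seq == pad_token_id).sum())
    if pad_token_id == eos_token_id:
        num_paddings -= 1  # keep one token: it is the eos, needed for the avg logprob
    return seq[:-num_paddings] if num_paddings > 0 else seq

seq = torch.tensor([1, 2, 3, 50257, 50257, 50257])  # 50257 acts as both pad and eos here
print(strip_trailing_padding(seq, 50257, 50257))    # tensor([1, 2, 3, 50257])
```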

# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
@@ -914,6 +1002,10 @@ class WhisperGenerationMixin(GenerationMixin):
temperature,
)

# remove eos token
if seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]

seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
@@ -956,14 +1048,19 @@ class WhisperGenerationMixin(GenerationMixin):
return current_segments

def _postprocess_outputs(
self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
self,
seek_outputs,
decoder_input_ids,
return_token_timestamps,
generation_config,
is_shortform,
):
# remove all previously passed decoder input ids
start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
# should happen only if it is the first generated segment
start_idx = decoder_input_ids.shape[-1]

if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, start_idx:]
return seek_outputs, seek_outputs
return seek_outputs[:, start_idx:], seek_outputs

if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
@@ -973,9 +1070,6 @@ class WhisperGenerationMixin(GenerationMixin):
num_frames=num_frames,
num_input_ids=decoder_input_ids.shape[-1],
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]

seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]

def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
if beam_indices is not None and key == "scores":
@@ -1011,7 +1105,7 @@ class WhisperGenerationMixin(GenerationMixin):

return values[batch_idx].cpu()

sequence_tokens = seek_outputs["sequences"]
sequence_tokens = seek_outputs["sequences"][:, start_idx:]
seek_outputs = [
{
k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
@@ -1026,7 +1120,7 @@ class WhisperGenerationMixin(GenerationMixin):
# Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
outputs = {}
for key in seek_outputs[0].keys():
if key in ["sequences", "beam_indices"]:
if key in ["sequences", "beam_indices", "token_timestamps"]:
outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
outputs[key] = tuple(
@@ -1057,6 +1151,10 @@ class WhisperGenerationMixin(GenerationMixin):
else:
outputs[key] = None

token_timestamps = outputs.get("token_timestamps", None)
if token_timestamps is not None:
model_output_type = dict

return model_output_type(**outputs)

def _need_fallback(
@@ -1083,7 +1181,9 @@ class WhisperGenerationMixin(GenerationMixin):
else:
scores = seek_outputs[index]["scores"]
logprobs = self._retrieve_avg_logprobs(
scores, seek_sequence, generation_config.eos_token_id, temperature
scores,
seek_sequence,
temperature,
)

if logprobs < generation_config.logprob_threshold:
@@ -1179,13 +1279,6 @@ class WhisperGenerationMixin(GenerationMixin):
if no_speech_threshold is not None:
logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))

# when passing temperature as a list it cannot just be ignored => throw error in this case
if isinstance(temperature, (list, tuple)):
raise ValueError(
f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
)

@staticmethod
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
@@ -1768,7 +1861,7 @@ class WhisperGenerationMixin(GenerationMixin):
return compression_ratio

@staticmethod
def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
def _retrieve_avg_logprobs(scores, tokens, temperature):
rescale_temperature = temperature if temperature > 0.0 else 1
scores = torch.stack(scores).to(tokens.device)

@@ -1780,10 +1873,10 @@ class WhisperGenerationMixin(GenerationMixin):
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)

# retrieve logprob of selected tokens and sum
sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
# don't remove the eos token logprob! it counts in avg_logprob calculation in the original implementation
sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0]))

avg_logprobs = sum_logprobs / (length + 1)
avg_logprobs = sum_logprobs / len(tokens)
return avg_logprobs
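The corrected average log-probability can be reproduced in isolation (random scores, purely illustrative): the log-probabilities of the chosen tokens, including the eos position, are summed and divided by the number of tokens.

```python
# Toy reproduction of the corrected avg_logprobs: keep the eos logprob, divide by len(tokens).
import torch
import torch.nn.functional as F

tokens = torch.tensor([3, 1, 7])           # 7 plays the role of eos here
scores = [torch.randn(8) for _ in tokens]  # one score vector per generated token

logprobs = F.log_softmax(torch.stack(scores).float(), dim=-1)
sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0]))
avg_logprobs = sum_logprobs / len(tokens)  # previously: / (length + 1) with eos masked out
print(float(avg_logprobs))
```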

@staticmethod
@@ -1799,6 +1892,7 @@ class WhisperGenerationMixin(GenerationMixin):
prev_idx,
idx,
return_token_timestamps,
decoder_input_ids,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1807,6 +1901,7 @@ class WhisperGenerationMixin(GenerationMixin):
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
idx_offset = decoder_input_ids.shape[-1]
device = seek_sequence.device

# If whisper predicted a "end of segment" via a timestep token, let's go ever each
@@ -1838,12 +1933,13 @@ class WhisperGenerationMixin(GenerationMixin):
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"tokens": sliced_tokens,
"idxs": (idx_offset + last_slice, idx_offset + current_slice),
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
token_timestamps[idx_offset + last_slice : idx_offset + current_slice] + time_offset[prev_idx]
)
last_slice = current_slice

@@ -1871,11 +1967,14 @@ class WhisperGenerationMixin(GenerationMixin):
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"idxs": (idx_offset, idx_offset + len(seek_sequence)),
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segments[-1]["token_timestamps"] = (
token_timestamps[idx_offset : idx_offset + len(seek_sequence)] + time_offset[prev_idx]
)
segment_offset = seek_num_frames[prev_idx]

return segments, segment_offset
@@ -17,14 +17,22 @@
from __future__ import annotations

import inspect
import os
import tempfile
import traceback
import unittest

import numpy as np

from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
from transformers import GenerationConfig, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import (
is_tf_available,
require_read_token,
require_tf,
require_tokenizers,
run_test_in_subprocess,
slow,
)
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

@@ -749,7 +757,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -772,13 +782,29 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

ds = load_dataset("legacy-datasets/common_voice", "ja", split="test", streaming=True, trust_remote_code=True)
# update generation config
generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")

token = os.getenv("HF_HUB_READ_TOKEN", True)
ds = load_dataset(
"mozilla-foundation/common_voice_6_1",
"ja",
split="test",
streaming=True,
trust_remote_code=True,
token=token,
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]["array"]
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
language="<|ja|>",
task="transcribe",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -786,7 +812,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
language="<|en|>",
task="transcribe",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -794,7 +825,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
input_features,
do_sample=False,
max_length=20,
language="<|ja|>",
task="translate",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

@@ -825,10 +861,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):

# fmt: off
EXPECTED_IDS = [
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
[50258, 50259, 50359, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
[50258, 50259, 50359, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
[50258, 50259, 50359, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
[50258, 50259, 50359, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
]
# fmt: on

@@ -836,10 +872,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):

# fmt: off
EXPECTED_TRANSCRIPT = [
" Mr. Quilter is the apostle of the middle classes and we are glad to",
" Mr. Quilter is the apostle of the middle classes and we are glad",
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
" He tells us that at this festive season of the year, with Christmas and roast",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all"
]
# fmt: on

@@ -1009,6 +1045,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)

@slow
@require_read_token
def test_large_generation_multilingual(self):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)

@@ -445,6 +445,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
self.maxDiff = 3000

def prepare_config_and_inputs_for_generate(self, batch_size=2):
config, inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
inputs_dict["force_unique_generate_call"] = True
return config, inputs_dict

def test_config(self):
self.config_tester.run_common_tests()

@@ -1891,8 +1896,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
"ja",
split="test",
streaming=True,
token=token,
trust_remote_code=True,
token=token,
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))

@@ -2144,11 +2149,16 @@ class WhisperModelIntegrationTests(unittest.TestCase):
},
{
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
"timestamp": (39.80, 45.36),
# "timestamp": (39.80, 45.36),
# above is the expected output on A100.
# on CI T4s, due to sligth difference in floating points operations, expected is below
"timestamp": (39.80, 45.38),
},
{
"text": " can discover in it but little of rocky Ithaca.",
"timestamp": (45.36, 49.0),
# "timestamp": (45.36, 49.0),
# see above
"timestamp": (45.38, 49.0),
},
{
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
@@ -2275,20 +2285,20 @@ class WhisperModelIntegrationTests(unittest.TestCase):

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400],
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400],
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200],
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000],
[0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800],
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600]
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
])
# fmt: on

self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))

@slow
def test_large_token_timestamp_generation(self):
def test_small_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.to(torch_device)

input_speech = self._load_datasamples(4)
@@ -2305,10 +2315,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000],
[0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800]
[0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600],
[0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
])
# fmt: on

@@ -3331,6 +3341,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):

# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4
torch._dynamo.reset()

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")