Mirror of https://github.com/huggingface/transformers.git
[whisper] compile compatibility with long-form decoding (#31772)
* [whisper] compile compatibility with long-form decoding
* clarify comment
* fix after rebase
* finalise
* fix bsz
* fix cache split
* remove contiguous
* style
* finish
* update doc
* prevent cuda graph trace
This commit is contained in: parent 9451a38526, commit e234061cdd
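For orientation, here is a minimal, hedged sketch of what this change enables: long-form transcription with a static KV cache and a compiled forward pass. It mirrors the test added in this commit; the checkpoint name comes from that test, while `raw_audio` (a mono waveform longer than 30 seconds, sampled at 16 kHz) is an assumed input that is not part of this diff.

```python
# Sketch only: long-form Whisper generation with torch.compile + a static cache.
# `raw_audio` is an assumed >30 s mono waveform sampled at 16 kHz.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")

# A static cache keeps tensor shapes fixed so the compiled graph can be reused across segments.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Long-form inputs are passed without truncation, with an attention mask for the padding.
inputs = processor(
    raw_audio,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
).to("cuda")

predicted_ids = model.generate(**inputs, return_timestamps=True, condition_on_prev_tokens=True)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```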
@@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
 ```

-Whisper is compatible with the following optimisations:
+Whisper is compatible with the following optimisations for both short and long-form generation:
 - [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
 - [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
 - [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
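The attention backends listed in the documentation hunk above are selected when the model is loaded. As a hedged illustration (not part of this diff): SDPA is already the default on `torch>=2.1.1`, and Flash Attention 2 additionally requires the `flash-attn` package and half-precision weights on GPU.

```python
# Illustrative only: selecting an attention implementation at load time.
import torch
from transformers import WhisperForConditionalGeneration

# PyTorch SDPA (the default for torch>=2.1.1, spelled out explicitly here)
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en", attn_implementation="sdpa"
)

# Flash Attention 2 (needs `pip install flash-attn` and fp16/bf16 weights on GPU)
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
```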
@@ -101,7 +101,8 @@ As an example, the following codesnippet enables SDPA and `torch.compile` for up
 ... ).input_features

 >>> # Compile the forward pass
->>> _ = model.generate(input_features)
+>>> for _ in range(2):
+>>>     model.generate(input_features)

 >>> # Generate token ids using compiled graph (fast!)
 >>> predicted_ids = model.generate(input_features)
@@ -126,12 +126,24 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att


 def _pad_to_max_length(
-    current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
+    current_segments,
+    pad_token_id,
+    device,
+    padding_side="right",
+    padding="longest",
+    bos_token_tensor=None,
+    cut_off_length=None,
 ):
     max_total_length = 0
     sequences = []
-    if padding not in ["right", "left"]:
-        raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")

     for current_segment_list in current_segments:
         if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
@@ -150,9 +162,10 @@ def _pad_to_max_length(
         else:
             sequences.append(torch.tensor([], device=device))

+    max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
-        pad = (0, pad_length) if padding == "right" else (pad_length, 0)
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
         sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)

     sequences = torch.stack(sequences, dim=0)
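The two hunks above split the old `padding` argument of `_pad_to_max_length` into `padding_side` (which edge receives the pad tokens) and `padding` ("longest" pads to the longest sequence in the batch, "max_length" pads to a fixed `cut_off_length + 1` so a static cache always sees the same prompt length). Below is a small self-contained sketch of the same semantics; the helper name is illustrative and is not the library function itself.

```python
# Sketch of the padding semantics above, not the library helper.
import torch
import torch.nn.functional as F

def pad_sequences(sequences, pad_token_id, padding_side="right", padding="longest", cut_off_length=None):
    # "max_length" yields a fixed width (cut_off_length + 1); "longest" adapts to the batch.
    max_total_length = cut_off_length + 1 if padding == "max_length" else max(len(s) for s in sequences)
    padded = []
    for seq in sequences:
        pad_length = max_total_length - len(seq)
        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
        padded.append(F.pad(seq, pad=pad, value=pad_token_id))
    return torch.stack(padded, dim=0)

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad_sequences(batch, pad_token_id=0, padding_side="left"))                     # tensor([[1, 2, 3], [0, 4, 5]])
print(pad_sequences(batch, pad_token_id=0, padding="max_length", cut_off_length=5))  # width 6, right-padded
```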
@@ -672,6 +685,7 @@ class WhisperGenerationMixin:
                 return_token_timestamps=return_token_timestamps,
                 do_condition_on_prev_tokens=do_condition_on_prev_tokens,
                 is_shortform=is_shortform,
+                batch_size=batch_size,
                 kwargs=kwargs,
             )

@@ -712,7 +726,7 @@
         )

         sequences = _pad_to_max_length(
-            final_segments, generation_config.pad_token_id, device=self.device, padding="right"
+            final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
         )

         # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
@@ -775,6 +789,7 @@
         return_token_timestamps,
         do_condition_on_prev_tokens,
         is_shortform,
+        batch_size,
         kwargs,
     ):
         kwargs = copy.copy(kwargs)
@@ -798,6 +813,22 @@
             for key in ["do_sample", "temperature", "num_beams"]:
                 if key in generate_kwargs:
                     del generate_kwargs[key]
+
+            cur_bsz = decoder_input_ids.shape[0]
+            if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
+                segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
+                decoder_input_ids = F.pad(
+                    decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
+                )
+                if generate_kwargs.get("decoder_attention_mask") is not None:
+                    generate_kwargs["decoder_attention_mask"] = F.pad(
+                        generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
+                    )
+                if generate_kwargs.get("encoder_outputs") is not None:
+                    generate_kwargs["encoder_outputs"] = F.pad(
+                        generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
+                    )
+
             seek_outputs = super().generate(
                 segment_input,
                 generation_config=generation_config,
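The block added above keeps the batch dimension constant when a static cache is used: once some samples in the batch have finished decoding, the remaining inputs are padded back up to the original `batch_size` so the graph captured by `torch.compile` (whose shapes are fixed) can be reused, and the dummy rows are sliced away again right after generation (see the next hunk). A rough sketch of the idea with made-up shapes and token ids:

```python
# Sketch: keeping the batch dimension fixed for a compiled graph (shapes and ids are illustrative).
import torch
import torch.nn.functional as F

batch_size = 4                                  # batch size the graph was captured with
segment_input = torch.randn(2, 80, 3000)        # only 2 samples still decoding (cur_bsz = 2)
decoder_input_ids = torch.full((2, 1), 50258)   # illustrative start-of-transcript id
cur_bsz = decoder_input_ids.shape[0]

if cur_bsz < batch_size:
    # F.pad takes (left, right) pairs starting from the last dimension,
    # so the final pair pads the leading batch dimension with dummy rows.
    segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
    decoder_input_ids = F.pad(decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=50257)

# ... run the compiled forward/generate on the fixed-size batch ...
outputs = decoder_input_ids                     # stand-in for the real generation output
outputs = outputs[:cur_bsz]                     # drop the dummy rows afterwards
```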
@@ -820,6 +851,10 @@
                 is_shortform=is_shortform,
             )

+            if cur_bsz < batch_size:
+                seek_sequences = seek_sequences[:cur_bsz]
+                seek_outputs = seek_outputs[:cur_bsz]
+
             # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
             # Loop over each decoded audio individually as each decoding can be of a different length
             new_fallback_index_map = []
@@ -925,17 +960,27 @@
                 if not is_shortform:
                     # we don't save `past_key_values` as this is too costly for longform
                     return None
+                elif isinstance(values, EncoderDecoderCache):
+                    all_past_key_values = []
+                    for layer_idx in range(self.config.decoder_layers):
+                        layer_past_key_values = []
+                        for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
+                            for v in [cache_cls.key_cache, cache_cls.value_cache]:
+                                layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
+                        all_past_key_values.append(tuple(layer_past_key_values))
+                    return tuple(all_past_key_values)
                 else:
-                    return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
+                    all_past_key_values = []
+                    for v in range(len(values)):
+                        layer_past_key_values = []
+                        for w in values[v]:
+                            layer_past_key_values.append(w[batch_idx][None].cpu())
+                        all_past_key_values.append(tuple(layer_past_key_values))
+                    return tuple(all_past_key_values)

             return values[batch_idx].cpu()

         sequence_tokens = seek_outputs["sequences"]
-
-        if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
-            if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
-                seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()
-
         seek_outputs = [
             {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
             for i in range(sequence_tokens.shape[0])
@@ -1613,11 +1658,14 @@
             one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
             prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None

+            padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
+
             prev_tokens = _pad_to_max_length(
                 active_segments,
                 generation_config.pad_token_id,
                 device=device,
-                padding="left",
+                padding_side="left",
+                padding=padding,
                 bos_token_tensor=prev_ids,
                 cut_off_length=cut_off_length,
             )
@@ -1835,8 +1835,10 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM

             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

-            if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
+            if decoder_position_ids is not None:
                 decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)

         if cache_position is None:
             cache_position = torch.arange(
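The comment in the hunk above is about strides rather than contiguity. A quick check of general PyTorch behaviour (not Whisper-specific) shows why `.contiguous()` is a no-op for a batch-size-1 slice, while `clone(memory_format=torch.contiguous_format)` gives every decoding step the same stride:

```python
# Why `.contiguous()` is not enough to keep CUDA-graph captures stable at batch size 1.
import torch

# position ids at two consecutive decoding steps: same (1, 1) shape, different parent width
step5 = torch.arange(5).reshape(1, 5)[:, -1:]   # stride (5, 1)
step6 = torch.arange(6).reshape(1, 6)[:, -1:]   # stride (6, 1)

print(step5.is_contiguous(), step6.is_contiguous())   # True True -> .contiguous() returns them unchanged
print(step5.stride(), step6.stride())                 # (5, 1) (6, 1): the stride guard changes every step
print(step5.clone(memory_format=torch.contiguous_format).stride())  # (1, 1): stable across steps
```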
@@ -1845,6 +1847,36 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
         elif use_cache:
             cache_position = cache_position[-decoder_input_ids.shape[1] :]

+        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+        decoder_input_ids = decoder_input_ids.contiguous()
+
+        if (
+            isinstance(past_key_values, EncoderDecoderCache)
+            and (
+                isinstance(past_key_values.self_attention_cache, StaticCache)
+                or isinstance(past_key_values.cross_attention_cache, StaticCache)
+            )
+            and decoder_attention_mask is not None
+            and decoder_attention_mask.ndim == 2
+        ):
+            batch_size, sequence_length = decoder_input_ids.shape
+            device = decoder_input_ids.device
+
+            dtype = self.proj_out.weight.dtype
+            min_dtype = torch.finfo(dtype).min
+
+            decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+                decoder_attention_mask,
+                sequence_length=sequence_length,
+                target_length=past_key_values.self_attention_cache.get_max_length(),
+                dtype=dtype,
+                device=device,
+                min_dtype=min_dtype,
+                cache_position=cache_position,
+                batch_size=batch_size,
+            )
+
         return {
             "encoder_outputs": encoder_outputs,
             "past_key_values": past_key_values,
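When a static cache is active, the block above expands the 2-D `decoder_attention_mask` ahead of the attention layers into a 4-D additive mask whose last dimension equals the cache's fixed maximum length, so its shape never changes between decoding steps. The following is only a rough, self-contained illustration of that shape and fill pattern; it is not the `_prepare_4d_causal_attention_mask_with_cache_position` helper, and the tensors are made up:

```python
# Illustrative sketch: 2-D padding mask -> 4-D additive causal mask with a fixed target length.
import torch

def sketch_4d_causal_mask(padding_mask_2d, sequence_length, target_length, dtype, cache_position):
    min_dtype = torch.finfo(dtype).min
    batch_size = padding_mask_2d.shape[0]
    # start fully masked: (batch, 1, query_len, fixed_cache_len)
    mask = torch.full((batch_size, 1, sequence_length, target_length), min_dtype, dtype=dtype)
    # causal part: a query at cache position p may attend to cache slots <= p
    allowed = torch.arange(target_length)[None, :] <= cache_position[:, None]
    mask = mask.masked_fill(allowed[None, None, :, :], 0.0)
    # re-mask key positions that the 2-D mask marks as padding
    key_is_pad = padding_mask_2d[:, None, None, :target_length] == 0
    return mask.masked_fill(key_is_pad, min_dtype)

mask_2d = torch.ones(1, 8, dtype=torch.long)     # cache length 8, no padding in this toy case
mask_4d = sketch_4d_causal_mask(mask_2d, sequence_length=1, target_length=8,
                                dtype=torch.float32, cache_position=torch.tensor([2]))
print(mask_4d.shape)   # torch.Size([1, 1, 1, 8]) -- constant for every decoding step
```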
@@ -3386,6 +3386,66 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         # assert re-ordered generations match those from eager
         assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()

+    @slow
+    def test_tiny_static_generation_long_form(self):
+        import torch._dynamo.config
+
+        # only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
+        torch._dynamo.config.cache_size_limit = 4
+
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model.to(torch_device)
+
+        dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+        input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]
+
+        inputs = processor(
+            input_speech,
+            return_tensors="pt",
+            padding="longest",
+            truncation=False,
+            return_attention_mask=True,
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(torch_device)
+
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.6,
+            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+            "compression_ratio_threshold": 1.35,
+            "condition_on_prev_tokens": True,  # conditioning on prev tokens introduces a recompile on the second time step
+            "logprob_threshold": -1.0,
+            "num_beams": 1,
+        }
+
+        set_seed(42)
+        eager_generated_ids = model.generate(**inputs, **gen_kwargs)
+
+        # compile the forward pass and assert equivalence
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        set_seed(42)
+        static_generated_ids = model.generate(**inputs, **gen_kwargs)
+        assert (eager_generated_ids == static_generated_ids).all()
+
+        # check the compiled graph can be re-used and that the cache is correctly reset
+        # reverse the ordering of the input features
+        input_features = inputs.input_features
+        permutation_idx = (
+            torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
+        )
+        input_features = input_features[permutation_idx, ...]
+        attention_mask = inputs.attention_mask[permutation_idx, ...]
+
+        set_seed(42)
+        static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
+        # assert re-ordered generations match those from eager
+        assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
+

 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None: