[whisper] compile compatibility with long-form decoding (#31772)

* [whisper] compile compatibility with long-form decoding

* clarify comment

* fix after rebase

* finalise

* fix bsz

* fix cache split

* remove contiguous

* style

* finish

* update doc

* prevent cuda graph trace
Sanchit Gandhi 2024-08-01 18:10:56 +08:00 committed by GitHub
parent 9451a38526
commit e234061cdd
4 changed files with 156 additions and 15 deletions


@@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
Whisper is compatible with the following optimisations:
Whisper is compatible with the following optimisations for both short and long-form generation:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
@@ -101,7 +101,8 @@ As an example, the following code snippet enables SDPA and `torch.compile` for up
... ).input_features
>>> # Compile the forward pass
>>> _ = model.generate(input_features)
>>> for _ in range(2):
...     model.generate(input_features)
>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
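
The snippet above now runs two warm-up calls so that both the prefill and the decoding graphs are captured before the timed call. For long-form audio (inputs longer than 30 seconds) the same recipe applies once the static cache is enabled. Below is a minimal sketch assembled from the snippet above and the new `test_tiny_static_generation_long_form` test further down, assuming a CUDA device and the `distil-whisper/meanwhile` dataset used in that test:

```python
import torch
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")

# The static cache gives the decoder fixed-shape key/value tensors, which torch.compile
# needs in order not to recompile on every new sequence length
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# A long-form (> 30 s) sample, processed without truncation as in the test below
dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
sample = dataset[0]["audio"]
inputs = processor(
    sample["array"],
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
).to("cuda")

# Warm-up: the first calls trigger compilation of the prefill and decoding graphs
for _ in range(2):
    model.generate(**inputs, return_timestamps=True)

# Subsequent calls re-use the compiled graphs (fast!)
predicted_ids = model.generate(**inputs, return_timestamps=True)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```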


@@ -126,12 +126,24 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
def _pad_to_max_length(
current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
current_segments,
pad_token_id,
device,
padding_side="right",
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
):
max_total_length = 0
sequences = []
if padding not in ["right", "left"]:
raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
if padding not in ["longest", "max_length"]:
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
@@ -150,9 +162,10 @@ def _pad_to_max_length(
else:
sequences.append(torch.tensor([], device=device))
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding == "right" else (pad_length, 0)
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
sequences = torch.stack(sequences, dim=0)
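
The signature change above separates two things the old `padding` argument conflated: `padding_side` picks which side receives the pad tokens, while `padding` now picks between padding to the longest sequence (`"longest"`) and padding to a fixed width of `cut_off_length + 1` (`"max_length"`). The fixed width is what lets a compiled static-cache graph always see the same prompt shape. A standalone toy sketch of the same logic, using made-up token tensors rather than Whisper segments:

```python
import torch
import torch.nn.functional as F

def pad_sequences(sequences, pad_token_id, padding_side="right", padding="longest", cut_off_length=None):
    """Toy version of the padding done in `_pad_to_max_length` (illustration only)."""
    if padding == "max_length" and cut_off_length is None:
        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
    max_total_length = max(len(seq) for seq in sequences)
    if padding == "max_length":
        max_total_length = cut_off_length + 1  # mirrors `cut_off_length + 1` in the diff above
    padded = []
    for seq in sequences:
        pad_length = max_total_length - len(seq)
        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
        padded.append(F.pad(seq, pad=pad, value=pad_token_id))
    return torch.stack(padded, dim=0)

tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad_sequences(tokens, pad_token_id=0, padding_side="left"))
# tensor([[1, 2, 3],
#         [0, 4, 5]])
print(pad_sequences(tokens, pad_token_id=0, padding="max_length", cut_off_length=4))
# tensor([[1, 2, 3, 0, 0],
#         [4, 5, 0, 0, 0]])
```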
@@ -672,6 +685,7 @@ class WhisperGenerationMixin:
return_token_timestamps=return_token_timestamps,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
kwargs=kwargs,
)
@@ -712,7 +726,7 @@ class WhisperGenerationMixin:
)
sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding="right"
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
@@ -775,6 +789,7 @@ class WhisperGenerationMixin:
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
kwargs,
):
kwargs = copy.copy(kwargs)
@@ -798,6 +813,22 @@ class WhisperGenerationMixin:
for key in ["do_sample", "temperature", "num_beams"]:
if key in generate_kwargs:
del generate_kwargs[key]
cur_bsz = decoder_input_ids.shape[0]
if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
decoder_input_ids = F.pad(
decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
)
if generate_kwargs.get("decoder_attention_mask") is not None:
generate_kwargs["decoder_attention_mask"] = F.pad(
generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
)
if generate_kwargs.get("encoder_outputs") is not None:
generate_kwargs["encoder_outputs"] = F.pad(
generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
)
seek_outputs = super().generate(
segment_input,
generation_config=generation_config,
@@ -820,6 +851,10 @@ class WhisperGenerationMixin:
is_shortform=is_shortform,
)
if cur_bsz < batch_size:
seek_sequences = seek_sequences[:cur_bsz]
seek_outputs = seek_outputs[:cur_bsz]
# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
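
The block added above keeps the batch dimension constant for the compiled graph: when some audio streams finish early (`cur_bsz < batch_size`) and the static cache is active, the segment features, decoder prompt, decoder attention mask and encoder outputs are all padded back up to the full `batch_size`, and the extra rows are dropped again straight after generation (the `seek_sequences[:cur_bsz]` slice). A self-contained sketch of this pad-then-slice pattern on dummy tensors, with a trivial reduction standing in for the compiled forward pass:

```python
import torch
import torch.nn.functional as F

def run_fixed_batch(features, full_batch_size, pad_value=0.0):
    """Pad a shrinking batch up to a fixed size, run a stand-in model, then drop the padded rows."""
    cur_bsz = features.shape[0]
    if cur_bsz < full_batch_size:
        # F.pad pads the last dimensions first, so (0, 0, 0, 0, 0, n) appends n rows along dim 0 of a 3D tensor
        features = F.pad(features, (0, 0, 0, 0, 0, full_batch_size - cur_bsz), value=pad_value)
    outputs = features.sum(dim=(1, 2))  # stand-in for the compiled forward pass: one value per row
    return outputs[:cur_bsz]            # keep only the rows that correspond to real audio

full_batch = torch.randn(4, 80, 3000)  # a batch of 4 log-mel spectrograms
print(run_fixed_batch(full_batch, full_batch_size=4).shape)      # torch.Size([4])
print(run_fixed_batch(full_batch[:1], full_batch_size=4).shape)  # torch.Size([1]); the graph still saw batch size 4
```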
@@ -925,17 +960,27 @@ class WhisperGenerationMixin:
if not is_shortform:
# we don't save `past_key_values` as this is too costly for longform
return None
elif isinstance(values, EncoderDecoderCache):
all_past_key_values = []
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.key_cache, cache_cls.value_cache]:
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
else:
return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
all_past_key_values = []
for v in range(len(values)):
layer_past_key_values = []
for w in values[v]:
layer_past_key_values.append(w[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()
seek_outputs = [
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
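
With an `EncoderDecoderCache`, the new branch above walks the self-attention and cross-attention caches layer by layer; the legacy tuple-of-tuples path now does the same with an explicit loop. A toy sketch of the legacy-format split, assuming random tensors in place of real key/value states:

```python
import torch

def split_past_key_values(past_key_values, batch_idx):
    """Keep only row `batch_idx` of every key/value tensor, moved to CPU (legacy tuple layout)."""
    all_past_key_values = []
    for layer in past_key_values:        # one tuple of tensors per decoder layer
        layer_past_key_values = []
        for w in layer:                  # self-attn key, self-attn value, cross-attn key, cross-attn value
            layer_past_key_values.append(w[batch_idx][None].cpu())  # `[None]` keeps a batch dimension of 1
        all_past_key_values.append(tuple(layer_past_key_values))
    return tuple(all_past_key_values)

# 2 layers, 4 tensors per layer, shape (batch, heads, seq_len, head_dim)
dummy_cache = tuple(tuple(torch.randn(3, 6, 10, 64) for _ in range(4)) for _ in range(2))
split = split_past_key_values(dummy_cache, batch_idx=1)
print(len(split), split[0][0].shape)  # 2 torch.Size([1, 6, 10, 64])
```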
@@ -1613,11 +1658,14 @@ class WhisperGenerationMixin:
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
device=device,
padding="left",
padding_side="left",
padding=padding,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)


@@ -1835,8 +1835,10 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
if decoder_position_ids is not None:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
# This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying strides, which retriggers a capture.
decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)
if cache_position is None:
cache_position = torch.arange(
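
The `clone(memory_format=torch.contiguous_format)` above is needed because `torch.compile` with `mode="reduce-overhead"` guards on tensor strides when capturing CUDA graphs: a `(1, 1)` slice of `position_ids` counts as contiguous whatever its strides are, so `.contiguous()` hands it back unchanged and the stride keeps shifting from step to step. A small sketch (illustration only, no compilation involved) of the difference:

```python
import torch

for total_len in (8, 9, 10):
    position_ids = torch.arange(total_len).reshape(1, total_len)
    last = position_ids[:, -1:]  # the (1, 1) slice used for single-token decoding with batch size 1
    print(last.stride(), last.is_contiguous())  # stride changes every step, yet the tensor is "contiguous"
    print(last.contiguous().stride())           # .contiguous() is a no-op here, so the odd stride survives
    print(last.clone(memory_format=torch.contiguous_format).stride())  # always (1, 1)
```

Because torchdynamo guards on those strides, the surviving odd stride would retrigger a graph capture on every step, while the fresh clone keeps the guard stable.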
@@ -1845,6 +1847,36 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1] :]
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
decoder_input_ids = decoder_input_ids.contiguous()
if (
isinstance(past_key_values, EncoderDecoderCache)
and (
isinstance(past_key_values.self_attention_cache, StaticCache)
or isinstance(past_key_values.cross_attention_cache, StaticCache)
)
and decoder_attention_mask is not None
and decoder_attention_mask.ndim == 2
):
batch_size, sequence_length = decoder_input_ids.shape
device = decoder_input_ids.device
dtype = self.proj_out.weight.dtype
min_dtype = torch.finfo(dtype).min
decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
decoder_attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.self_attention_cache.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,


@@ -3386,6 +3386,66 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
@slow
def test_tiny_static_generation_long_form(self):
import torch._dynamo.config
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]
inputs = processor(
input_speech,
return_tensors="pt",
padding="longest",
truncation=False,
return_attention_mask=True,
sampling_rate=16_000,
)
inputs = inputs.to(torch_device)
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step
"logprob_threshold": -1.0,
"num_beams": 1,
}
set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)
# compile the forward pass and assert equivalence
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)
assert (eager_generated_ids == static_generated_ids).all()
# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
input_features = inputs.input_features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
attention_mask = inputs.attention_mask[permutation_idx, ...]
set_seed(42)
static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
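
The `cache_size_limit = 4` at the top of the test is the budget for distinct compiled graphs: with the fixed batch size and `max_length`-padded prompts introduced in this PR, long-form generation should need at most four (two prefill and two decode variants). A small sketch, unrelated to Whisper, of how changing input shapes interact with that limit (the exact behaviour depends on the PyTorch version and its dynamic-shape settings):

```python
import torch

torch._dynamo.config.cache_size_limit = 4  # at most 4 compiled graphs per function, as in the test above

@torch.compile(fullgraph=True)
def step(x):
    return x * 2 + 1

for seq_len in (3, 4, 5):
    # Each new input shape may trigger a recompilation; once the limit is exceeded,
    # torch._dynamo falls back to eager execution for further shapes.
    step(torch.randn(1, seq_len))
```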
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None: