Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Whisper: fix static cache CI (#35852)

* fix
* remove overridden method
* small change

parent 9725e5be2f
commit 365fecb4d0
```diff
@@ -406,23 +406,28 @@ class GenerationMixin:
             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

         # 4. Create missing `position_ids` on the fly
+        attention_mask = (
+            kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
+        )
+        attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
+        position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
         if (
             attention_mask is not None
-            and kwargs.get("position_ids") is None
-            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
+            and kwargs.get(position_ids_key) is None
+            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
         ):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
+            kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)

         # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-        for model_input_name in ["position_ids", "token_type_ids"]:
+        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
                 if past_key_values is not None:
                     current_input_length = (
                         model_inputs["inputs_embeds"].shape[1]
-                        if model_inputs["inputs_embeds"] is not None
+                        if model_inputs.get("inputs_embeds") is not None
                         else model_inputs[input_ids_key].shape[1]
                     )
                     model_input = model_input[:, -current_input_length:]
```
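The first hunk makes `GenerationMixin` build decoder position ids on the fly for encoder-decoder models such as Whisper, reusing the same cumsum trick it already applied to decoder-only models. A minimal, self-contained sketch of that computation; `DummyConfig` is a stand-in for a real model config and not part of transformers:

```python
import torch


class DummyConfig:
    is_encoder_decoder = True  # Whisper-style encoder-decoder model


config = DummyConfig()
attention_mask = torch.tensor([[0, 1, 1, 1]])  # 0 = padding, 1 = real token

# Encoder-decoder models route the ids through `decoder_position_ids`;
# decoder-only models keep using `position_ids`.
position_ids_key = "decoder_position_ids" if config.is_encoder_decoder else "position_ids"

# Same computation as in the hunk: positions count real tokens, padding gets a dummy value.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

kwargs = {position_ids_key: position_ids}
print(kwargs)  # {'decoder_position_ids': tensor([[1, 0, 1, 2]])}
```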
```diff
@@ -469,7 +474,7 @@ class GenerationMixin:
                 past_key_values=past_key_values,
             )
         if attention_mask is not None:
-            model_inputs["attention_mask"] = attention_mask
+            model_inputs[attention_mask_key] = attention_mask

         # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
```
```diff
@@ -1234,7 +1234,7 @@ class WhisperGenerationMixin(GenerationMixin):
     def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
         set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
         extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
-        set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
+        set_inputs({"inputs": segment_input, "input_ids": decoder_input_ids, **extra_kwargs})

     @staticmethod
     def _retrieve_total_input_frames(input_features, input_stride, kwargs):
```
```diff
@@ -1255,7 +1255,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
         )

         if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)

         # embed positions
         if input_ids is not None:
```
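The `WhisperDecoder` change replaces the broadcastable `(1, seq_len)` fallback with position ids expanded to the full batch, presumably to keep the tensor shape in line with the per-sample `decoder_position_ids` that the generic path now supplies. A shape-only sketch under that assumption:

```python
import torch

batch_size, seq_len = 3, 1                      # one new token per sequence while decoding
input_shape = (batch_size, seq_len)
cache_position = torch.arange(5, 5 + seq_len)   # e.g. decoding step 5

old_position_ids = cache_position.unsqueeze(0)                             # shape (1, 1)
new_position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)   # shape (3, 1)

print(old_position_ids.shape, new_position_ids.shape)
# torch.Size([1, 1]) torch.Size([3, 1])
```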
```diff
@@ -1806,88 +1806,6 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
             encoder_attentions=outputs.encoder_attentions,
         )

-    def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past_key_values=None,
-        use_cache=None,
-        encoder_outputs=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time
-        # this function needs to be touched, let's try to sort out the commonalities between the two and remove the
-        # overwrite.
-
-        decoder_position_ids = None
-        if decoder_attention_mask is not None:
-            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
-
-        past_length = 0
-        if past_key_values is not None:
-            if isinstance(past_key_values, EncoderDecoderCache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            else:
-                past_length = past_key_values[0][0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if decoder_input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = decoder_input_ids.shape[1] - 1
-
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
-            if decoder_position_ids is not None:
-                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
-                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-                decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)
-
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
-            )
-        elif use_cache:
-            cache_position = cache_position[-decoder_input_ids.shape[1] :]
-
-        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-        decoder_input_ids = decoder_input_ids.contiguous()
-
-        if (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and (
-                isinstance(past_key_values.self_attention_cache, StaticCache)
-                or isinstance(past_key_values.cross_attention_cache, StaticCache)
-            )
-            and decoder_attention_mask is not None
-            and decoder_attention_mask.ndim == 2
-        ):
-            batch_size, sequence_length = decoder_input_ids.shape
-
-            decoder_attention_mask = self.get_decoder()._prepare_4d_causal_attention_mask_with_cache_position(
-                decoder_attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.self_attention_cache.get_max_cache_shape(),
-                dtype=self.proj_out.weight.dtype,
-                device=decoder_input_ids.device,
-                cache_position=cache_position,
-                batch_size=batch_size,
-            )
-
-        return {
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
-            "decoder_input_ids": decoder_input_ids,
-            "use_cache": use_cache,
-            "decoder_attention_mask": decoder_attention_mask,
-            "decoder_position_ids": decoder_position_ids,
-            "cache_position": cache_position,
-        }
-

 class WhisperDecoderWrapper(WhisperPreTrainedModel):
     """
```
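With the override removed, Whisper's decoder position ids come from the generic `GenerationMixin` code in the first hunk rather than the clamp-based formula deleted here. A small self-contained comparison of the two computations on a toy left-padded mask; the values differ only at padded positions, which the decoder attention mask is expected to ignore anyway:

```python
import torch

# Toy left-padded decoder attention mask: 0 = padding, 1 = real token
decoder_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                       [1, 1, 1, 1, 1]])

# Removed Whisper-specific computation (from the deleted override above)
whisper_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)

# Generic GenerationMixin computation (see the first hunk)
generic_ids = decoder_attention_mask.long().cumsum(-1) - 1
generic_ids.masked_fill_(decoder_attention_mask == 0, 1)

print(whisper_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
print(generic_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```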
```diff
@@ -3323,8 +3323,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         input_features = input_features.to(torch_device)
         eager_generated_ids = model.generate(input_features, max_new_tokens=64)

+        # Using static cache compiles forward for each decoding step, so we don't have to manually compile
         model.generation_config.cache_implementation = "static"
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

         # compile the forward pass and assert equivalence
         static_generated_ids = model.generate(input_features, max_new_tokens=64)
```
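For reference, a minimal sketch of the usage pattern these tests exercise: generate once eagerly, switch `generation_config.cache_implementation` to `"static"`, and generate again. The `"openai/whisper-tiny"` checkpoint and the random log-mel features are assumptions for illustration, so the decoded text itself is meaningless:

```python
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

model_id = "openai/whisper-tiny"  # assumed small checkpoint for illustration
processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).eval()

# Whisper expects log-mel features of shape (batch, num_mel_bins, 3000);
# random values are placeholders just to trace the generation path.
input_features = torch.randn(1, model.config.num_mel_bins, 3000)

eager_ids = model.generate(input_features, max_new_tokens=32)

# Switch to the static decoder cache, as in the CI tests above; no manual
# torch.compile call is needed for this equivalence-style check.
model.generation_config.cache_implementation = "static"
static_ids = model.generate(input_features, max_new_tokens=32)

print(processor.batch_decode(eager_ids, skip_special_tokens=True))
print(processor.batch_decode(static_ids, skip_special_tokens=True))
```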
```diff
@@ -3379,9 +3379,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         set_seed(42)
         eager_generated_ids = model.generate(**inputs, **gen_kwargs)

-        # compile the forward pass and assert equivalence
+        # Using static cache compiles forward for each decoding step, so we don't have to manually compile
         model.generation_config.cache_implementation = "static"
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

         set_seed(42)
         static_generated_ids = model.generate(**inputs, **gen_kwargs)
```