Cache: init empty cache when use_cache (#34274)

* fix

* fix tests

* fix copies

* add docs

* Revert "add docs"

This reverts commit 32d35634f1.

* qwen move deltas

* mllama can potentiall fullgraph compile

* enable mllama compile and fix tests

* remove mllama fixes
This commit is contained in:
Raushan Turganbay 2024-11-25 10:11:33 +01:00 committed by GitHub
parent 1339a14dca
commit c1a8520419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 57 additions and 64 deletions

View File

@ -25,7 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -1300,6 +1300,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(

View File

@ -24,7 +24,7 @@ from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
@ -1618,6 +1618,9 @@ class MllamaTextModel(MllamaPreTrainedModel):
hidden_states = inputs_embeds
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
@ -1845,7 +1848,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
super().__init__(config.get_text_config())
self.text_config = config.get_text_config()
self.vocab_size = self.text_config.vocab_size
self.model = MllamaTextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
self.model = MllamaTextModel._from_config(self.text_config)
self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
self.post_init()

View File

@ -780,6 +780,9 @@ class NemotronModel(NemotronPreTrainedModel):
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

View File

@ -21,7 +21,7 @@
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from ...activations import ACT2FN
from ...cache_utils import Cache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
@ -549,10 +549,6 @@ class Qwen2VLAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += cache_position[0] + 1
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@ -646,16 +642,6 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
logger.warning_once(
@ -784,9 +770,6 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@ -1116,6 +1099,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
)
use_cache = False
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -1428,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
self.model = Qwen2VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
@ -1507,7 +1494,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None:
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
@ -1600,25 +1587,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return position_ids, mrope_position_deltas
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas
return model_kwargs
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1638,6 +1606,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
@ -1726,8 +1695,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if position_ids is None and input_ids is not None:
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
input_ids=None,
@ -1739,6 +1724,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
@ -1769,7 +1755,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
rope_deltas=self.rope_deltas,
)
def prepare_inputs_for_generation(
@ -1798,22 +1784,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
rope_deltas = kwargs.get("rope_deltas", None)
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
else:
batch_size, seq_length = input_ids.shape
delta = (
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
)
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
if cache_position[0] != 0:
pixel_values = None
pixel_values_videos = None
@ -1854,7 +1824,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw,
"rope_deltas": rope_deltas,
"cache_position": cache_position,
}
)
return model_inputs

View File

@ -1531,6 +1531,14 @@ class GenerationTesterMixin:
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads
# some models have diffent num-head for query vs key/value so we need to assign correct value
# BUT only after `per_head_embed_dim` is set
num_attention_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_attention_heads
)
past_kv = outputs["past_key_values"]
self.assertEqual(len(past_kv), num_hidden_layers)

View File

@ -333,6 +333,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
def test_generate_compile_fullgraph(self):
pass
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):

View File

@ -2343,7 +2343,8 @@ class ModelTesterMixin:
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
# model might return non-tensors objects (e.g. Cache class)
elif isinstance(tuple_object, torch.Tensor):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5