Cache: init empty cache when use_cache (#34274)
* fix
* fix tests
* fix copies
* add docs
* Revert "add docs"
This reverts commit 32d35634f1.
* qwen move deltas
* mllama can potentially fullgraph compile
* enable mllama compile and fix tests
* remove mllama fixes
This commit is contained in:
parent 1339a14dca
commit c1a8520419
@@ -25,7 +25,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -1300,6 +1300,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        # torch.jit.trace() doesn't support cache objects in the output
+        if use_cache and past_key_values is None and not torch.jit.is_tracing():
+            past_key_values = DynamicCache()
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
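The pattern these hunks introduce is the same everywhere: when the caller requests caching but passes no cache, the model lazily builds an empty `DynamicCache` (skipped under `torch.jit.trace`, which cannot return cache objects). A minimal, self-contained sketch of the idea, assuming a transformers version that exposes `transformers.cache_utils.DynamicCache`; `toy_forward` and the shapes are invented for illustration:

import torch
from transformers.cache_utils import DynamicCache


def toy_forward(hidden_states, past_key_values=None, use_cache=True):
    # Lazily create an empty cache so downstream layers can always call
    # `past_key_values.update(...)` instead of branching on `None`.
    # torch.jit.trace() can't return cache objects, hence the tracing guard.
    if use_cache and past_key_values is None and not torch.jit.is_tracing():
        past_key_values = DynamicCache()
    return hidden_states, past_key_values


hidden = torch.randn(1, 4, 8)
_, cache = toy_forward(hidden)
print(type(cache).__name__, cache.get_seq_length())  # DynamicCache 0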
@@ -24,7 +24,7 @@ from torch import nn

 from ... import PreTrainedModel
 from ...activations import ACT2FN
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -1618,6 +1618,9 @@ class MllamaTextModel(MllamaPreTrainedModel):

         hidden_states = inputs_embeds

+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
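From the caller's side, the benefit is that the cache object returned by the model can simply be fed back in on the next step. A hedged sketch using only the public `DynamicCache` API; the `step` helper and the tensor shapes are made up:

import torch
from transformers.cache_utils import DynamicCache


def step(key, value, past_key_values=None, use_cache=True):
    # Invented single-layer "attention step" that only exercises the cache API.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    if use_cache:
        # update() appends the new key/value states for layer 0 and returns
        # the full key/value tensors seen so far.
        key, value = past_key_values.update(key, value, layer_idx=0)
    return key, value, past_key_values


k = torch.randn(1, 2, 3, 4)  # (batch, kv_heads, seq_len, head_dim)
v = torch.randn(1, 2, 3, 4)
_, _, cache = step(k, v)                             # prefill: 3 tokens cached
_, _, cache = step(k[:, :, :1], v[:, :, :1], cache)  # decode step: +1 token
print(cache.get_seq_length())  # 4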
@@ -1845,7 +1848,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
         super().__init__(config.get_text_config())
         self.text_config = config.get_text_config()
         self.vocab_size = self.text_config.vocab_size
-        self.model = MllamaTextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
+        self.model = MllamaTextModel._from_config(self.text_config)
         self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)

         self.post_init()
@@ -780,6 +780,9 @@ class NemotronModel(NemotronPreTrainedModel):
             )
             use_cache = False

+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

@@ -21,7 +21,7 @@

 import math
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
-from ...cache_utils import Cache, SlidingWindowCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
@@ -549,10 +549,6 @@ class Qwen2VLAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += cache_position[0] + 1
-
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -646,16 +642,6 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         if position_embeddings is None:
             logger.warning_once(
@@ -784,9 +770,6 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
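The `kv_seq_len` bookkeeping deleted in the three attention classes above becomes redundant once a cache object is always present, because the cache tracks its own length. A hedged illustration of that accounting with arbitrary shapes:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 2, 5, 8)  # (batch, kv_heads, seq_len, head_dim)
v = torch.randn(1, 2, 5, 8)
cache.update(k, v, layer_idx=0)
print(cache.get_seq_length())  # 5 tokens cached for layer 0

# One decoding step appends a single token.
cache.update(k[:, :, -1:], v[:, :, -1:], layer_idx=0)
print(cache.get_seq_length())  # 6

# The equivalent of the removed bookkeeping is simply:
kv_seq_len = cache.get_seq_length()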
@@ -1116,6 +1099,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
             )
             use_cache = False

+        # torch.jit.trace() doesn't support cache objects in the output
+        if use_cache and past_key_values is None and not torch.jit.is_tracing():
+            past_key_values = DynamicCache()
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

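The `torch.jit.is_tracing()` guard matters because `torch.jit.trace` only supports tensors (and simple containers of tensors) as outputs, so a `DynamicCache` must not be created or returned while tracing. A small hedged sketch of the guard in isolation; `maybe_cache` and `traced_fn` are invented names:

import torch
from transformers.cache_utils import DynamicCache


def maybe_cache(x):
    # Skip creating the (non-tensor) cache object while torch.jit.trace is running.
    cache = None if torch.jit.is_tracing() else DynamicCache()
    return x * 2, cache


def traced_fn(t):
    # Trace only the tensor output; the cache stays out of the traced graph.
    return maybe_cache(t)[0]


x = torch.ones(2)
_, cache = maybe_cache(x)  # eager call: cache is a DynamicCache
traced = torch.jit.trace(traced_fn, x)
print(type(cache).__name__, traced(x))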
@@ -1428,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         self.model = Qwen2VLModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
+        self.rope_deltas = None  # cache rope_deltas here

         # Initialize weights and apply final processing
         self.post_init()
@@ -1507,7 +1494,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         video_token_id = self.config.video_token_id
         vision_start_token_id = self.config.vision_start_token_id
         mrope_position_deltas = []
-        if image_grid_thw is not None or video_grid_thw is not None:
+        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
             total_input_ids = input_ids
             if attention_mask is None:
                 attention_mask = torch.ones_like(total_input_ids)
@@ -1600,25 +1587,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):

         return position_ids, mrope_position_deltas

-    def _update_model_kwargs_for_generation(
-        self,
-        outputs: ModelOutput,
-        model_kwargs: Dict[str, Any],
-        is_encoder_decoder: bool = False,
-        num_new_tokens: int = 1,
-    ) -> Dict[str, Any]:
-        model_kwargs = super()._update_model_kwargs_for_generation(
-            outputs=outputs,
-            model_kwargs=model_kwargs,
-            is_encoder_decoder=is_encoder_decoder,
-            num_new_tokens=num_new_tokens,
-        )
-
-        if getattr(outputs, "rope_deltas", None) is not None:
-            model_kwargs["rope_deltas"] = outputs.rope_deltas
-
-        return model_kwargs
-
     @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
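The deleted override existed only to copy `rope_deltas` from each step's output back into `model_kwargs`; with the deltas now cached on the module as `self.rope_deltas`, the stock `_update_model_kwargs_for_generation` is enough. A hedged toy contrast of the two data flows (the `Toy*` classes are invented):

# Before: per-step state had to be threaded through `model_kwargs` by the
# generate loop, which is what the removed override did.
class ToyBefore:
    def forward(self, **kwargs):
        rope_deltas = kwargs.get("rope_deltas")  # must be re-injected every step
        return {"rope_deltas": rope_deltas if rope_deltas is not None else 0}


# After: per-sequence state lives on the module and is computed once at prefill.
class ToyAfter:
    def __init__(self):
        self.rope_deltas = None

    def forward(self, is_prefill):
        if is_prefill or self.rope_deltas is None:
            self.rope_deltas = 3  # stand-in for get_rope_index()'s second output
        return self.rope_deltas


m = ToyAfter()
print(m.forward(is_prefill=True), m.forward(is_prefill=False))  # 3 3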
@@ -1638,6 +1606,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
         r"""
         Args:
@@ -1726,8 +1695,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             if attention_mask is not None:
                 attention_mask = attention_mask.to(inputs_embeds.device)

-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

         outputs = self.model(
             input_ids=None,
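The arithmetic in the new decode branch can be checked in isolation: the current position is the text-side `cache_position` shifted by the per-sequence `rope_deltas` saved at prefill, then broadcast to the three M-RoPE axes. A hedged, torch-only walk-through with made-up numbers:

import torch

batch_size, seq_length = 2, 1            # decoding: one new token per sequence
cache_position = torch.tensor([7])       # 7 tokens already in the KV cache
rope_deltas = torch.tensor([[-2], [5]])  # per-sequence offsets from prefill, shape (batch, 1)

delta = cache_position[0] + rope_deltas  # shape (batch, 1)
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, batch, seq) for the t/h/w axes
print(position_ids[:, :, 0])  # tensor([[ 5, 12], [ 5, 12], [ 5, 12]])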
@@ -1739,6 +1724,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         hidden_states = outputs[0]
@@ -1769,7 +1755,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            rope_deltas=rope_deltas,
+            rope_deltas=self.rope_deltas,
         )

     def prepare_inputs_for_generation(
@@ -1798,22 +1784,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
             input_ids = input_ids[:, cache_position]

-        rope_deltas = kwargs.get("rope_deltas", None)
-        if attention_mask is not None and position_ids is None:
-            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-            else:
-                batch_size, seq_length = input_ids.shape
-                delta = (
-                    cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=input_ids.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
         if cache_position[0] != 0:
             pixel_values = None
             pixel_values_videos = None
@@ -1854,7 +1824,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
                 "pixel_values_videos": pixel_values_videos,
                 "image_grid_thw": image_grid_thw,
                 "video_grid_thw": video_grid_thw,
-                "rope_deltas": rope_deltas,
                 "cache_position": cache_position,
             }
         )
         return model_inputs
@@ -1531,6 +1531,14 @@ class GenerationTesterMixin:
         embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
         per_head_embed_dim = embed_dim // num_attention_heads

+        # some models have diffent num-head for query vs key/value so we need to assign correct value
+        # BUT only after `per_head_embed_dim` is set
+        num_attention_heads = (
+            text_config.num_key_value_heads
+            if getattr(text_config, "num_key_value_heads", None) is not None
+            else num_attention_heads
+        )
+
         past_kv = outputs["past_key_values"]
         self.assertEqual(len(past_kv), num_hidden_layers)

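The reason the test needs `num_key_value_heads`: with grouped-query attention the cached key/value tensors carry `num_key_value_heads` heads, while `per_head_embed_dim` still comes from the full attention-head count. A hedged shape check with invented config values:

import torch

# Invented GQA-style config values.
hidden_size = 64
num_attention_heads = 8
num_key_value_heads = 2
batch_size, seq_len = 1, 5

per_head_embed_dim = hidden_size // num_attention_heads  # 8, derived from the query heads

# What a per-layer cached key (or value) tensor is expected to look like:
expected_kv_shape = (batch_size, num_key_value_heads, seq_len, per_head_embed_dim)
key_cache = torch.zeros(expected_kv_shape)
assert key_cache.shape == torch.Size([1, 2, 5, 8])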
@@ -333,6 +333,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

+    @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
+    def test_generate_compile_fullgraph(self):
+        pass
+

 @require_torch
 class Qwen2VLIntegrationTest(unittest.TestCase):
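The skip reason refers to input preparation branching on runtime tensor values (e.g. `cache_position[0] != 0`), which causes a graph break that `torch.compile(..., fullgraph=True)` turns into an error instead of falling back to eager. A hedged minimal reproduction of that failure mode, not the transformers code path itself; exact behavior depends on the PyTorch version:

import torch


def branches_on_data(x, cache_position):
    # A Python `if` over a tensor's value is data-dependent control flow.
    if cache_position[0] != 0:
        return x + 1
    return x - 1


compiled = torch.compile(branches_on_data, fullgraph=True)
try:
    compiled(torch.ones(2), torch.tensor([3]))
except Exception as err:
    # Under fullgraph=True the graph break typically surfaces as an error.
    print(type(err).__name__)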
@@ -2343,7 +2343,8 @@ class ModelTesterMixin:
                 recursive_check(tuple_iterable_value, dict_iterable_value)
             elif tuple_object is None:
                 return
-            else:
+            # model might return non-tensors objects (e.g. Cache class)
+            elif isinstance(tuple_object, torch.Tensor):
                 self.assertTrue(
                     torch.allclose(
                         set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
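The old `else:` branch assumed anything that was not a dict, an iterable, or `None` had to be a tensor, so `torch.allclose` failed once models began returning `Cache` objects inside their tuple outputs; the new `elif isinstance(tuple_object, torch.Tensor)` simply skips non-tensor members. A hedged standalone version of that check:

import torch
from transformers.cache_utils import DynamicCache


def check_equivalent(tuple_object, dict_object, atol=1e-5):
    if tuple_object is None:
        return True
    # Model outputs may contain non-tensor objects (e.g. a Cache); only compare tensors.
    if isinstance(tuple_object, torch.Tensor):
        return torch.allclose(tuple_object, dict_object, atol=atol)
    return True  # skip anything else, mirroring the new elif


t = torch.randn(2, 3)
print(check_equivalent(t, t.clone()))                    # True
print(check_equivalent(DynamicCache(), DynamicCache()))  # True (skipped, not compared)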