Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix static generation when compiling! (#28937)
* wow I was scared!
* fix everything
* nits
* make it BC?
* add todo
* nits
* is_tracing should still be used to pass tracing tests
* nits
* some nits to make sure generation works with static cache uncompiled
* fix sdpa
* fix FA2 for both static and dynamic in a better way?
* style
* fix-copies
* fix fix copies
* fix sequential beam search
* style
* use `keys_to_ignore`
* nit
* correct dtype inference when init
* :( the fix for FA2 is still not optimal to investigate!
* styling
* nits
* nit
* this might work better
* add comment
* Update src/transformers/models/llama/modeling_llama.py
* "position_ids" -> "cache_position"
* style
* nit
* Remove changes that should not be propagated just yet
* Apply suggestions from code review
* Styling
* make sure we raise an error for static cache with FA2 enabled
* move to the bottom of the signature
* style
* Update src/transformers/models/llama/modeling_llama.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_llama.py
* nit in the name

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in: parent 609a1767e8, commit f3788b09e1
src/transformers/cache_utils.py

@@ -344,17 +344,15 @@ class StaticCache(Cache):
             The default `dtype` to use when initializing the layer.
     """

-    def __init__(
-        self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
-    ) -> None:
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.head_dim = config.hidden_size // config.num_attention_heads
+        self.dtype = dtype if dtype is not None else torch.float32
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
-        self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
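Illustrative sketch (not part of the diff): with the new signature, the cache dtype follows the explicit argument and falls back to float32 instead of reading config.torch_dtype; _setup_cache below now passes the dtype of the model weights. The tiny max_cache_len just keeps the allocation cheap.

import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

config = LlamaConfig(num_attention_heads=32)
# Explicit dtype: cache tensors are allocated in that dtype.
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float16)
assert cache.key_cache.dtype == torch.float16
# Omitted dtype now falls back to float32 rather than config.torch_dtype.
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu")
assert cache.key_cache.dtype == torch.float32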
@@ -386,20 +384,23 @@ class StaticCache(Cache):
         Return:
             A tuple containing the updated key and value states.
         """
-        new_cache_positions = cache_kwargs.get("position_ids")
+        new_cache_positions = cache_kwargs.get("cache_position")
         k_out = self.key_cache
         v_out = self.value_cache

         k_out[:, :, new_cache_positions] = key_states
         v_out[:, :, new_cache_positions] = value_states

-        self.seen_tokens += key_states.shape[-2]
+        self.seen_tokens += key_states.shape[2]
         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
         return self.seen_tokens

+    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return self.max_cache_len
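Illustrative sketch (not part of the diff): write positions are now passed under the "cache_position" key, and update returns the full preallocated buffers. Shapes follow the LlamaConfig defaults, so head_dim = 4096 // 32 = 128.

import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

config = LlamaConfig(num_attention_heads=32)
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

key = torch.ones(1, 32, 1, 128)    # (batch, kv_heads, new_tokens, head_dim)
value = torch.ones(1, 32, 1, 128)
# Write the new token at position 0; the full preallocated buffers come back.
k_out, v_out = cache.update(key, value, 0, cache_kwargs={"cache_position": torch.arange(1)})
assert k_out.shape == (1, 32, 16, 128)
assert cache.get_seq_length() == 1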
src/transformers/generation/utils.py

@@ -4776,8 +4776,9 @@ def _split_model_inputs(
     # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
     # ModelOutput object.
     # bool should not be split but replicated for each split
-    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
-    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+    keys_to_ignore = ["cache_position", "encoder_outputs"]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

     # we split the tensors and tuples of tensors
     data_split_list = [
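Illustrative sketch (not part of the diff): sequential beam search splits a batched input dict into per-row dicts, but cache_position indexes time steps, not batch rows, so it must be replicated rather than sliced. split_inputs below is a simplified stand-in for _split_model_inputs; only the key-classification rule mirrors the diff.

import torch

def split_inputs(model_input: dict, full_batch_size: int, split_size: int = 1):
    keys = model_input.keys()
    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
    keys_to_ignore = ["cache_position", "encoder_outputs"]
    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
    return [
        {k: model_input[k] for k in bool_keys}  # replicated as-is
        | {k: model_input[k][i : i + split_size] for k in non_bool_keys}  # sliced along the batch dim
        for i in range(0, full_batch_size, split_size)
    ]

batch = {"input_ids": torch.zeros(4, 7, dtype=torch.long), "use_cache": True, "cache_position": torch.arange(7)}
splits = split_inputs(batch, full_batch_size=4)
assert splits[0]["input_ids"].shape == (1, 7)     # sliced
assert splits[0]["cache_position"].shape == (7,)  # replicated, not sliced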
src/transformers/models/llama/modeling_llama.py

@@ -29,7 +29,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -303,6 +303,7 @@ class LlamaAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -333,21 +334,13 @@ class LlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
         past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
-
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -356,7 +349,8 @@ class LlamaAttention(nn.Module):
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+            if cache_position is not None:
+                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
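Illustrative sketch (not part of the diff): indexing the premask rows with cache_position yields a slice whose shape depends only on q_len and the cache length, with no data-dependent past_seen_tokens arithmetic for the compiler to specialize on. The toy premask is built the same way as the causal_mask buffer in _setup_cache below.

import torch

max_cache_len = 8
# Full causal premask: 1 above the diagonal marks disallowed positions.
causal = torch.triu(torch.full((max_cache_len, max_cache_len), fill_value=1.0), diagonal=1)
mask = causal[None, None, :, :] * torch.finfo(torch.float32).min  # (1, 1, max_len, max_len)

# Decoding the token at position 5: select one row by cache_position.
cache_position = torch.tensor([5])
sliced = mask[:, :, cache_position, :max_cache_len]
assert sliced.shape == (1, 1, 1, 8)
# Positions 0..5 are visible (0.0), positions 6..7 are masked (large negative).
assert (sliced[0, 0, 0, :6] == 0).all() and (sliced[0, 0, 0, 6:] < 0).all()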
@@ -410,6 +404,7 @@ class LlamaFlashAttention2(LlamaAttention):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         output_attentions = False
@@ -427,20 +422,14 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
         past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
-
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -603,6 +592,7 @@ class LlamaSdpaAttention(LlamaAttention):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -617,6 +607,7 @@ class LlamaSdpaAttention(LlamaAttention):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
         )

         bsz, q_len, _ = hidden_states.size()
@@ -629,29 +620,22 @@ class LlamaSdpaAttention(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
-
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        causal_mask = None
-        if attention_mask is not None:
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+        causal_mask = attention_mask
+        if attention_mask is not None and cache_position is not None:
+            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -666,7 +650,6 @@ class LlamaSdpaAttention(LlamaAttention):
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
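Illustrative sketch (not part of the diff), plain PyTorch only: with causal_mask now defaulting to the already-causal attention_mask buffer instead of None, the is_causal=causal_mask is None shortcut is dropped and causality always travels through the explicit mask. The two SDPA forms agree:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Causal attention via an explicit additive mask...
causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
# ...or via the built-in flag; with the flag gone, the mask must always be supplied.
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_mask, out_flag)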
@@ -703,6 +686,7 @@ class LlamaDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -736,6 +720,7 @@ class LlamaDecoderLayer(nn.Module):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -800,13 +785,20 @@ class LlamaPreTrainedModel(PreTrainedModel):
             module.weight.data[module.padding_idx].zero_()

     def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
             causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:
+            weights = layer.self_attn.o_proj.weight
             layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device
+                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
             )

     def _reset_cache(self):
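Illustrative sketch (not part of the diff) of the user-facing flow this commit fixes; the checkpoint name is a placeholder and loading details (device_map, accelerate) may vary. Setting cache_implementation = "static" is what routes generate() into _setup_cache, which now errors out early under flash_attention_2 and inherits dtype from the weights.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Static cache gives fixed shapes, which makes the decode step compile-friendly.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))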
@@ -932,6 +924,7 @@ class LlamaModel(LlamaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -951,12 +944,23 @@ class LlamaModel(LlamaPreTrainedModel):
             )
             use_cache = False

-        if use_cache and not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

         # embed positions
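Illustrative sketch (not part of the diff), toy numbers: when the caller supplies no cache_position, it is rebuilt from how many tokens the cache has already seen, and position_ids defaults to the same indices with a batch dimension.

import torch

past_seen_tokens = 5  # from past_key_values.get_seq_length()
new_tokens = 3        # inputs_embeds.shape[1]
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # default position_ids, batch dim added
assert cache_position.tolist() == [5, 6, 7]
assert position_ids.shape == (1, 3)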
@@ -980,6 +984,7 @@ class LlamaModel(LlamaPreTrainedModel):
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -989,6 +994,7 @@ class LlamaModel(LlamaPreTrainedModel):
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )

         hidden_states = layer_outputs[0]
@@ -1021,8 +1027,9 @@ class LlamaModel(LlamaPreTrainedModel):

     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
-            causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            return causal_mask
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None

         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
@@ -1051,14 +1058,11 @@ class LlamaModel(LlamaPreTrainedModel):
         )

         if self.config._attn_implementation == "sdpa":
-            if attention_mask is None:
-                return None
             is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
-            if not is_tracing and (torch.all(attention_mask == 1)):
-                return None
-            if is_tracing and seq_length == 1:
-                return None
-            causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
+            if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+                    dtype
+                )

         return causal_mask
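Illustrative sketch (not part of the diff): the SDPA branch no longer early-returns None, which changed graph shapes under tracing; a single predicate now decides whether fully-masked rows get rewritten. needs_mask_fixup is our name for a standalone restatement of that condition:

import torch

def needs_mask_fixup(attention_mask, is_tracing: bool) -> bool:
    # Only when not tracing, a mask was passed, and it actually masks something.
    return not is_tracing and attention_mask is not None and bool(torch.any(attention_mask != 1))

assert needs_mask_fixup(torch.tensor([[1, 1, 0]]), is_tracing=False)      # padding present
assert not needs_mask_fixup(torch.tensor([[1, 1, 1]]), is_tracing=False)  # nothing masked
assert not needs_mask_fixup(None, is_tracing=False)                       # no mask passed
assert not needs_mask_fixup(torch.tensor([[1, 0]]), is_tracing=True)      # keep traced graph stable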
@@ -1107,6 +1111,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1150,6 +1155,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         hidden_states = outputs[0]
@@ -1189,6 +1195,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
@@ -1228,9 +1235,17 @@ class LlamaForCausalLM(LlamaPreTrainedModel):

         if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
             # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
+            past_length = past_key_value.get_seq_length()
+            input_ids = input_ids[:, past_length:]
+            position_ids = position_ids[:, past_length:]
+
+        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
+        # same goes for position ids. Could also help with continued generation.
+        cache_position = kwargs.get("cache_position", None)
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_length, past_length + position_ids.shape[-1], device=position_ids.device
+            )

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
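Illustrative sketch (not part of the diff), toy numbers: with a static cache, already-cached tokens are trimmed from the front of input_ids and position_ids, and cache_position continues from past_length.

import torch

past_length = 4                           # past_key_value.get_seq_length()
input_ids = torch.arange(6).unsqueeze(0)  # 6 tokens fed in, 4 already cached
position_ids = torch.arange(6).unsqueeze(0)
input_ids = input_ids[:, past_length:]    # keep only the 2 new tokens
position_ids = position_ids[:, past_length:]
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1])
assert input_ids.tolist() == [[4, 5]]
assert cache_position.tolist() == [4, 5]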
@@ -1241,6 +1256,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
+                "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
src/transformers/models/persimmon/modeling_persimmon.py

@@ -823,7 +823,6 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -864,12 +863,6 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
src/transformers/models/phi/modeling_phi.py

@@ -1084,7 +1084,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+    # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1125,12 +1125,6 @@ class PhiForCausalLM(PhiPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
src/transformers/models/stablelm/modeling_stablelm.py

@@ -1048,7 +1048,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1089,12 +1088,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
tests/test_cache_utils.py

@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
         mha_config = LlamaConfig(num_attention_heads=32)
         mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mha_static_cache.update(
-            *_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
         self.assertTrue(cached_values.shape == (1, 32, 10, 128))
@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
         gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
         gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = gqa_static_cache.update(
-            *_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
         self.assertTrue(cached_values.shape == (1, 4, 10, 128))
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
         mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
         mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mqa_static_cache.update(
-            *_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
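Illustrative sketch (not part of the diff): the expected shapes in these tests follow from the cache being allocated per key/value head; head_dim = 4096 // 32 = 128 comes from the LlamaConfig defaults, while the second dimension tracks num_key_value_heads. A quick standalone check on CPU, assuming the post-commit behavior:

from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device="cpu")
# (batch, num_key_value_heads, max_cache_len, head_dim) = (1, 4, 10, 4096 // 32)
assert cache.key_cache.shape == (1, 4, 10, 128)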