Cache: new Cache format in decoder-only models (#31421)

* draft bart with new cache

* add cache for decoder-only models

* revert utils

* modify docstring

* revert bart

* minor fixes

* fix copies (not related)

* revert tests

* remove enc-dec related code

* remove bloom

* remove opt (enc-dec)

* update docstring

* git, codegen, gpt_neo, gpt_neox, gptj

* clean up

* copied from statements

* revert

* tmp

* update warning msg

* forgot git

* add more flags

* run-slow git,codegen,gpt_neo,gpt_neox,gptj

* add cache flag to VLMs

* remove files

* style

* video LLMs also need a flag

* style

* llava will go in another PR

* style

* [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics

* Update src/transformers/models/gpt_neo/modeling_gpt_neo.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copy from

* deprecate until v4.45 and warn if not training

* nit

* fix test

* test static cache

* add more tests and fix models

* fix copies

* return sliding window mask

* run slow tests & fix + codestyle

* one more falcon fix for alibi

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Raushan Turganbay 2024-08-07 10:02:16 +05:00 committed by GitHub
parent 6af0854efa
commit a30c865f99
11 changed files with 1915 additions and 781 deletions
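For context, a minimal sketch of what the new `Cache` format looks like from the user side for the decoder-only models touched by this PR. The checkpoint name is only an example; `DynamicCache` is the cache class these models are migrated to.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

inputs = tokenizer("def hello():", return_tensors="pt")
past_key_values = DynamicCache()  # new-style cache object, grows as tokens are processed

with torch.no_grad():
    out = model(**inputs, past_key_values=past_key_values, use_cache=True)

# The model returns the same Cache object, updated in place, one entry per layer.
print(type(out.past_key_values), out.past_key_values.get_seq_length())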

View File

@ -1016,7 +1016,9 @@ class StaticCache(Cache):
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
self.key_cache: List[torch.Tensor] = []
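The switch to `getattr` above matters because several of the configs covered by this PR (CodeGen, GPT-J, GPT-Neo, ...) never define `num_key_value_heads` at all, so the old attribute access would raise. A small sketch of the fallback; the config classes are only examples:

from transformers import CodeGenConfig, LlamaConfig

def resolve_kv_heads(config):
    # Same pattern as the diff: fall back to num_attention_heads when the
    # attribute is missing or explicitly set to None.
    return (
        config.num_attention_heads
        if getattr(config, "num_key_value_heads", None) is None
        else config.num_key_value_heads
    )

print(resolve_kv_heads(CodeGenConfig()))                     # falls back to num_attention_heads
print(resolve_kv_heads(LlamaConfig(num_key_value_heads=8)))  # uses the explicit value (8)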

View File

@ -1473,7 +1473,7 @@ class GenerationMixin:
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
# models. May cause trouble with non-text modalities.
cache_dtype = self.lm_head.weight.dtype
cache_dtype = self.get_output_embeddings().weight.dtype
cache_kwargs = {
"config": self.config,

View File

@ -22,6 +22,8 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@ -34,6 +36,60 @@ _CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
_CONFIG_FOR_DOC = "CodeGenConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the 0 padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
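A standalone sketch of the mask this helper builds, trimmed to the non-padding branch with illustrative sizes: a `(batch, 1, query_length, target_length)` tensor holding 0 where a query token may attend, and the dtype minimum for future tokens and not-yet-filled static-cache slots.

import torch

sequence_length, target_length, batch_size = 3, 8, 1  # e.g. a 3-token prefill into a static cache of 8 slots
dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(0, sequence_length)     # prefill step: positions 0..2

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask.shape)         # torch.Size([1, 1, 3, 8])
print((causal_mask == 0).int())  # 1 where attention is allowed, 0 where it is masked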
# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
@ -57,20 +113,19 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
class CodeGenAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"causal_mask",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
@ -114,27 +169,17 @@ class CodeGenAttention(nn.Module):
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / self.scale_attn
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights += causal_mask
attn_weights = attn_weights / self.scale_attn
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
@ -150,12 +195,13 @@ class CodeGenAttention(nn.Module):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@ -200,18 +246,16 @@ class CodeGenAttention(nn.Module):
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
present = (key.to(hidden_states.dtype), value)
else:
present = None
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_dim,
"cache_position": cache_position,
}
key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@ -220,7 +264,7 @@ class CodeGenAttention(nn.Module):
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
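The `layer_past.update(...)` call above is the core of the migration: each attention layer hands its fresh key/value states to the shared `Cache` object and receives the full (past + new) tensors back, with `cache_kwargs` carrying RoPE sin/cos and `cache_position` for cache classes that need them. A small sketch of that contract with `DynamicCache` (shapes are illustrative):

import torch
from transformers import DynamicCache

batch, heads, head_dim = 1, 4, 8
cache = DynamicCache()

# prefill with 5 tokens on layer 0
k0, v0 = torch.randn(batch, heads, 5, head_dim), torch.randn(batch, heads, 5, head_dim)
k, v = cache.update(k0, v0, layer_idx=0)
print(k.shape)  # torch.Size([1, 4, 5, 8])

# decode one more token: the returned tensors now cover 6 positions
k1, v1 = torch.randn(batch, heads, 1, head_dim), torch.randn(batch, heads, 1, head_dim)
k, v = cache.update(k1, v1, layer_idx=0)
print(k.shape, cache.get_seq_length())  # torch.Size([1, 4, 6, 8]) 6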
@ -250,22 +294,23 @@ class CodeGenMLP(nn.Module):
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(nn.Module):
# Ignore copy
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = CodeGenAttention(config)
self.attn = CodeGenAttention(config, layer_idx)
self.mlp = CodeGenMLP(inner_dim, config)
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -277,6 +322,7 @@ class CodeGenBlock(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@ -303,6 +349,9 @@ class CodeGenPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
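The three class attributes added above are what the rest of the library keys on when deciding which cache implementations a model may use during generation. A hedged sketch of inspecting them (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
print(model._supports_cache_class)   # True after this PR -> Cache objects are accepted
print(model._supports_static_cache)  # True after this PR -> generate(..., cache_implementation="static") is allowed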
@ -374,6 +423,23 @@ CODEGEN_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -382,6 +448,10 @@ CODEGEN_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -397,7 +467,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
@ -421,7 +491,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -431,6 +501,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -439,85 +510,62 @@ class CodeGenModel(CodeGenPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
device = input_ids.device if input_ids is not None else inputs_embeds.device
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
seq_length = inputs_embeds.shape[1]
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = cache_position.unsqueeze(0)
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1, seq_length, hidden_states.size(-1))
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -526,26 +574,28 @@ class CodeGenModel(CodeGenPreTrainedModel):
block.__call__,
hidden_states,
None,
attention_mask,
causal_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
next_decoder_cache = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -557,16 +607,94 @@ class CodeGenModel(CodeGenPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
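The deprecation path in the forward above round-trips between the legacy tuple format and `DynamicCache`; a small sketch of that conversion with toy tensors:

import torch
from transformers import DynamicCache

# legacy format: one (key, value) pair per layer, each of shape (batch, heads, seq_len, head_dim)
legacy = tuple((torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 5

roundtrip = cache.to_legacy_cache()
print(torch.equal(roundtrip[0][0], legacy[0][0]))  # True: same tensors, just re-wrapped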
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@add_start_docstrings(
"""
@ -591,26 +719,31 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
# Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
@ -618,19 +751,45 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}
)
return model_inputs
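The `cache_position`-based slicing above replaces the old `past_length` bookkeeping. A toy illustration of what it does to `input_ids` between the prefill and decode steps:

import torch

input_ids = torch.tensor([[101, 102, 103, 104]])  # full prompt so far (4 tokens)

# prefill step: nothing is cached yet, cache_position covers the whole prompt
cache_position = torch.arange(0, 4)
print(input_ids[:, cache_position])  # all 4 tokens go through the model

# decode step: 4 tokens are cached, one new token was appended by generate()
input_ids = torch.tensor([[101, 102, 103, 104, 105]])
cache_position = torch.tensor([4])
print(input_ids[:, cache_position])  # tensor([[105]]) -> only the new token is fed in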
@ -644,7 +803,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -655,6 +814,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -676,6 +836,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = transformer_outputs[0]

View File

@ -24,10 +24,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...activations import get_activation
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@ -62,6 +61,60 @@ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
_CONFIG_FOR_DOC = "FalconConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the 0 padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
@ -244,7 +297,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__()
self.config = config
@ -257,6 +310,13 @@ class FalconAttention(nn.Module):
self.rope_theta = config.rope_theta
self.is_causal = True
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
@ -373,10 +433,11 @@ class FalconAttention(nn.Module):
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@ -391,25 +452,24 @@ class FalconAttention(nn.Module):
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
kv_seq_len += layer_past[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, self.num_heads, kv_length, head_dim]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=-2)
value_layer = torch.cat((past_value, value_layer), dim=-2)
cache_kwargs = {"cache_position": cache_position}
if alibi is None:
cache_kwargs.update({"sin": sin, "cos": cos})
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
kv_length = key_layer.shape[-2]
if use_cache:
present = (key_layer, value_layer)
else:
present = None
if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
# For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@ -417,6 +477,9 @@ class FalconAttention(nn.Module):
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
if alibi is None:
if self._use_sdpa and not output_attentions:
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
@ -448,9 +511,9 @@ class FalconAttention(nn.Module):
attn_output = self.dense(attn_output)
if output_attentions:
return attn_output, present, attention_scores
return attn_output, layer_past, attention_scores
else:
return attn_output, present
return attn_output, layer_past
else:
if self._use_sdpa and not output_attentions and head_mask is None:
@ -502,9 +565,9 @@ class FalconAttention(nn.Module):
attn_output = self.dense(attn_output)
if output_attentions:
return attn_output, present, attention_probs
return attn_output, layer_past, attention_probs
else:
return attn_output, present
return attn_output, layer_past
class FalconFlashAttention2(FalconAttention):
@ -529,10 +592,11 @@ class FalconFlashAttention2(FalconAttention):
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@ -547,20 +611,22 @@ class FalconFlashAttention2(FalconAttention):
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
kv_seq_len += layer_past[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
if layer_past is not None and use_cache:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size, self.num_heads, kv_length, head_dim]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=-2)
value_layer = torch.cat((past_value, value_layer), dim=-2)
past_key_value = (key_layer, value_layer) if use_cache else None
if layer_past is not None:
cache_kwargs = {"cache_position": cache_position}
if alibi is None:
cache_kwargs.update({"sin": sin, "cos": cos})
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
@ -614,7 +680,7 @@ class FalconFlashAttention2(FalconAttention):
if not output_attentions:
attn_weights = None
return attn_output, past_key_value, attn_weights
return attn_output, layer_past, attn_weights
class FalconMLP(nn.Module):
@ -641,12 +707,12 @@ FALCON_ATTENTION_CLASSES = {
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
@ -672,10 +738,11 @@ class FalconDecoderLayer(nn.Module):
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
residual = hidden_states
@ -696,6 +763,7 @@ class FalconDecoderLayer(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attention_output = attn_outputs[0]
@ -731,7 +799,7 @@ class FalconDecoderLayer(nn.Module):
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return outputs # hidden_states, past_kv, attentions
FALCON_START_DOCSTRING = r"""
@ -762,14 +830,23 @@ FALCON_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Each element of `past_key_values` is a tuple (past_key, past_value):
- past_key: [batch_size * num_heads, head_dim, kv_length]
- past_value: [batch_size * num_heads, kv_length, head_dim]
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -806,6 +883,10 @@ FALCON_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -821,6 +902,9 @@ class FalconPreTrainedModel(PreTrainedModel):
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -877,7 +961,7 @@ class FalconModel(FalconPreTrainedModel):
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
@ -904,7 +988,7 @@ class FalconModel(FalconPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
@ -913,6 +997,7 @@ class FalconModel(FalconPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -921,38 +1006,35 @@ class FalconModel(FalconPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-2]
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
alibi = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
batch_size, seq_length, _ = inputs_embeds.shape
if self.use_alibi:
mask = (
torch.ones(
@ -961,67 +1043,32 @@ class FalconModel(FalconPreTrainedModel):
if attention_mask is None
else attention_mask
)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
min_dtype = torch.finfo(alibi.dtype).min
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
min_dtype,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1 and attention_mask.device.type == "cuda":
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, head_mask, alibi
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = inputs_embeds
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -1030,28 +1077,30 @@ class FalconModel(FalconPreTrainedModel):
block.__call__,
hidden_states,
alibi,
attention_mask,
causal_mask,
position_ids,
head_mask[i],
layer_past,
past_key_values,
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
next_decoder_cache = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -1062,16 +1111,110 @@ class FalconModel(FalconPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
head_mask: torch.Tensor,
alibi: torch.Tensor,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
and head_mask is None
and alibi is None
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
batch_size, sequence_length, _ = input_tensor.shape
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
# We take care to integrate alibi bias in the causal_mask here
if head_mask is None and alibi is not None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
causal_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
causal_mask < -1,
min_dtype,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
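Falcon's alibi variants cannot use a plain 0/-inf additive mask, so the bias is folded into the 4D causal mask as shown above. A self-contained sketch of that `masked_fill` pattern with toy sizes and slopes (the real bias comes from `build_alibi_tensor` in this file):

import math
import torch

batch, heads, q_len, kv_len, head_dim = 1, 2, 3, 3, 8
min_dtype = torch.finfo(torch.float32).min

# toy causal mask: 0 where attention is allowed, min_dtype where it is not
causal = torch.triu(torch.full((batch, 1, q_len, kv_len), min_dtype), diagonal=1)

# toy alibi bias: one slope per head, growing linearly with the key position
slopes = torch.tensor([0.5, 0.25]).view(1, heads, 1, 1)
alibi = slopes * torch.arange(kv_len).view(1, 1, 1, kv_len)

mask_with_alibi = torch.masked_fill(alibi / math.sqrt(head_dim), causal < -1, min_dtype)
print(mask_with_alibi[0, 0])  # lower triangle holds the scaled bias, upper triangle stays masked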
@add_start_docstrings(
"The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
@ -1097,23 +1240,22 @@ class FalconForCausalLM(FalconPreTrainedModel):
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Cache, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: bool = True,
**kwargs,
) -> dict:
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
@ -1123,16 +1265,43 @@ class FalconForCausalLM(FalconPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and past_key_values is None:
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
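
For reference, a minimal sketch (assumed values, separate from the diff) of the `cache_position` slicing performed at the top of this method: once the prompt is cached, only the not-yet-processed token survives the slice.

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])  # 5 cached prompt tokens + 1 new token
cache_position = torch.tensor([5])                     # cache slots still to be filled

if input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]

print(input_ids)  # tensor([[16]]) -- only the unprocessed token is fed to the model
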
@ -1147,7 +1316,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
@ -1157,6 +1326,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1178,6 +1348,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = transformer_outputs[0]

View File

@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...file_utils import ModelOutput
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
@ -124,13 +125,20 @@ class GitEmbeddings(nn.Module):
class GitSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config, position_embedding_type=None, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@ -161,46 +169,31 @@ class GitSelfAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
cutoff = self.image_patch_tokens if pixel_values_present else 0
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
if past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
value_layer = torch.cat(
[value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
key_layer_past, value_layer_past = past_key_value.update(
key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
past_key_value = (
key_layer[:, :, cutoff:, :],
value_layer[:, :, cutoff:, :],
)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
if past_key_value is not None:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
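
A toy sketch (illustrative shapes, not part of the commit) of the GIT convention noted above: only the text part of the keys/values is written to the cache, and the image-patch prefix is concatenated back in front of the cached text states on every step.

import torch

cutoff = 3                                     # number of image-patch tokens
key_layer = torch.randn(1, 4, cutoff + 2, 8)   # (batch, heads, image + text tokens, head_dim)
cached_text_keys = torch.randn(1, 4, 6, 8)     # text keys returned by Cache.update so far

full_keys = torch.cat([key_layer[:, :, :cutoff, :], cached_text_keys], dim=2)
print(full_keys.shape)  # torch.Size([1, 4, 9, 8]) -- image prefix + cached text keys
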
@ -269,11 +262,10 @@ GIT_SELF_ATTENTION_CLASSES = {
class GitAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT
def __init__(self, config, position_embedding_type=None):
def __init__(self, config, position_embedding_type=None, layer_idx=None):
super().__init__()
self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
)
self.output = GitSelfOutput(config)
self.pruned_heads = set()
@ -302,7 +294,7 @@ class GitAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
@ -351,11 +343,11 @@ class GitOutput(nn.Module):
class GitLayer(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = GitAttention(config)
self.attention = GitAttention(config, layer_idx=layer_idx)
self.intermediate = GitIntermediate(config)
self.output = GitOutput(config)
@ -364,18 +356,17 @@ class GitLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
past_key_value=past_key_value,
pixel_values_present=pixel_values_present,
)
attention_output = self_attention_outputs[0]
@ -401,11 +392,10 @@ class GitLayer(nn.Module):
class GitEncoder(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
@ -413,7 +403,7 @@ class GitEncoder(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
@ -427,16 +417,23 @@ class GitEncoder(nn.Module):
)
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
@ -444,7 +441,7 @@ class GitEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
past_key_value,
past_key_values,
output_attentions,
)
else:
@ -452,26 +449,30 @@ class GitEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
past_key_value,
past_key_values,
output_attentions,
pixel_values_present,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
next_decoder_cache = layer_outputs[-1]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
next_cache,
all_hidden_states,
all_self_attentions,
]
@ -479,7 +480,7 @@ class GitEncoder(nn.Module):
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
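
The legacy round trip performed above can be reproduced in isolation; a small sketch with made-up tensors (not part of the commit):

import torch
from transformers.cache_utils import DynamicCache

legacy = tuple(
    (torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8))  # one (key, value) pair per layer
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)   # wrap the deprecated tuple format
print(cache.get_seq_length())                    # 6 cached tokens

round_trip = cache.to_legacy_cache()             # convert back before returning, as above
print(len(round_trip), round_trip[0][0].shape)   # 2 torch.Size([1, 4, 6, 8])
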
@ -494,6 +495,8 @@ class GitPreTrainedModel(PreTrainedModel):
config_class = GitConfig
base_model_prefix = "git"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -569,6 +572,23 @@ GIT_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -1136,19 +1156,13 @@ class GitModel(GitPreTrainedModel):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@ -1195,7 +1209,13 @@ class GitModel(GitPreTrainedModel):
seq_length = input_shape[1]
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -1327,7 +1347,7 @@ class GitForCausalLM(GitPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -1338,12 +1358,6 @@ class GitForCausalLM(GitPreTrainedModel):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@ -1522,7 +1536,16 @@ class GitForCausalLM(GitPreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
past_length = past_key_values.get_seq_length()
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
input_shape = input_ids.shape

View File

@ -23,7 +23,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
@ -68,6 +69,60 @@ _CONFIG_FOR_DOC = "GPTNeoConfig"
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the zero padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
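
Assuming the helper above is in scope, a toy call (illustrative values, not part of the commit) shows how a 2D padding mask becomes the 4D additive mask consumed by the attention layers; the left-padded slot of the second row ends up at the dtype minimum.

import torch

attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])  # second row is left-padded
min_dtype = torch.finfo(torch.float32).min
mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=3,
    target_length=3,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=min_dtype,
    cache_position=torch.arange(3),
    batch_size=2,
)
print(mask_4d.shape)                     # torch.Size([2, 1, 3, 3])
print(mask_4d[1, 0, 0, 0] == min_dtype)  # tensor(True) -- the padded slot is fully masked
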
def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
"""Load tf checkpoints in a pytorch model"""
try:
@ -149,7 +204,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
class GPTNeoSelfAttention(nn.Module):
def __init__(self, config, attention_type):
def __init__(self, config, attention_type, layer_id=None):
super().__init__()
self.config = config
@ -170,6 +225,7 @@ class GPTNeoSelfAttention(nn.Module):
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
self.is_causal = True
self.layer_id = layer_id
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
@ -208,6 +264,7 @@ class GPTNeoSelfAttention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
# Apply sliding window masking for local attention layers
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
@ -216,9 +273,9 @@ class GPTNeoSelfAttention(nn.Module):
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
@ -240,6 +297,7 @@ class GPTNeoSelfAttention(nn.Module):
head_mask=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
@ -250,15 +308,8 @@ class GPTNeoSelfAttention(nn.Module):
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
cache_kwargs = {"cache_position": cache_position}
key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@ -266,11 +317,11 @@ class GPTNeoSelfAttention(nn.Module):
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
return outputs # a, past_kv, (attentions)
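
A minimal sketch (toy tensors, not part of the commit) of the `Cache.update` contract relied on above: the call appends the new key/value states for the given layer and returns the full cached tensors.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key, value = cache.update(torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8), layer_idx=0)  # prefill
key, value = cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx=0)  # decode
print(key.shape)  # torch.Size([1, 4, 3, 8]) -- 2 prefill tokens + 1 decoded token
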
class GPTNeoFlashAttention2(GPTNeoSelfAttention):
@ -297,6 +348,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
head_mask=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
bsz, _, _ = hidden_states.size()
@ -309,15 +361,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
cache_kwargs = {"cache_position": cache_position}
key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
query_length = query.shape[2]
tgt_len = key.shape[2]
@ -330,6 +375,9 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
attn_dropout = self.config.attention_dropout if self.training else 0.0
if attention_mask is not None: # no matter the length, we just slice it
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
@ -371,7 +419,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
attn_output = self.out_proj(attn_weights_reshaped)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights_reshaped,)
@ -392,7 +440,9 @@ class GPTNeoAttention(nn.Module):
self.attention_type = self.attention_layers[layer_id]
if self.attention_type in ["global", "local"]:
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](
config, self.attention_type, layer_id
)
else:
raise NotImplementedError(
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
@ -407,6 +457,7 @@ class GPTNeoAttention(nn.Module):
head_mask=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
return self.attention(
hidden_states,
@ -415,6 +466,7 @@ class GPTNeoAttention(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
@ -436,7 +488,7 @@ class GPTNeoMLP(nn.Module):
class GPTNeoBlock(nn.Module):
def __init__(self, config, layer_id):
def __init__(self, config, layer_id=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
@ -453,6 +505,7 @@ class GPTNeoBlock(nn.Module):
head_mask=None,
use_cache=False,
output_attentions=False,
cache_position=None,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -463,6 +516,7 @@ class GPTNeoBlock(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@ -480,7 +534,7 @@ class GPTNeoBlock(nn.Module):
else:
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions, cross_attentions)
return outputs # hidden_states, past_kv, attentions
class GPTNeoPreTrainedModel(PreTrainedModel):
@ -496,6 +550,9 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO: needs a HybridCache
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -547,10 +604,23 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -595,6 +665,10 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -611,7 +685,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
@ -633,7 +706,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
@ -643,6 +716,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -651,58 +725,10 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
# Attention mask.
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_length)
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training:
if use_cache:
@ -711,10 +737,51 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
)
use_cache = False
presents = () if use_cache else None
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
seq_length = inputs_embeds.shape[1]
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1, seq_length, hidden_states.size(-1))
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -723,24 +790,26 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
block.__call__,
hidden_states,
None,
attention_mask,
causal_mask,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=past_key_values,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if use_cache:
next_decoder_cache = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -752,16 +821,94 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
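
For context, a hedged end-to-end sketch (assumed checkpoint name, not part of the commit) of reusing the cache this forward pass returns: the prompt is run once, then each new token is fed together with the previously returned `past_key_values`.

import torch
from transformers import AutoTokenizer, GPTNeoForCausalLM
from transformers.cache_utils import DynamicCache

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")  # assumed checkpoint
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")

inputs = tokenizer("The new cache format", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

next_token = out.logits[:, -1:].argmax(-1)
with torch.no_grad():
    out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.past_key_values.get_seq_length())  # prompt length + 1
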
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
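
The `_unmask_unattended` step above can be seen in isolation with a toy mask (illustrative values, not part of the commit): rows that are masked everywhere, e.g. pure left-padding rows, are reset to zeros so the SDPA memory-efficient kernel does not produce NaNs for them.

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

min_dtype = torch.finfo(torch.float32).min
mask = torch.full((1, 1, 2, 3), min_dtype)
mask[0, 0, 0, :2] = 0.0                        # row 0 may attend to two slots, row 1 to none

unmasked = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
print(unmasked[0, 0, 1])                       # tensor([0., 0., 0.]) -- fully masked row reset
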
@add_start_docstrings(
"""
@ -787,26 +934,30 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing inputs_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
@ -814,22 +965,47 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}
)
return model_inputs
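
The static-cache branch above is driven by the cache object itself; a small sketch (assumed checkpoint for the config, not part of the commit) of why `get_max_length()` is used as the mask's target length: a `StaticCache` is pre-allocated to `max_cache_len`, regardless of how many tokens have actually been written.

import torch
from transformers import AutoConfig
from transformers.cache_utils import StaticCache

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")  # assumed checkpoint
cache = StaticCache(config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

print(cache.get_max_length())  # 32 -- the 4D mask above is built against this fixed length
print(cache.get_seq_length())  # tensor(0) -- nothing has been written yet
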
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
@ -841,7 +1017,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
@ -852,6 +1028,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -873,6 +1050,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = transformer_outputs[0]
@ -957,7 +1135,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
@ -1081,7 +1259,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,

View File

@ -23,13 +23,14 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -52,6 +53,60 @@ _REAL_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neox-20b"
_CONFIG_FOR_DOC = "GPTNeoXConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the zero padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -64,6 +119,9 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_sdpa = True
def _init_weights(self, module):
@ -82,7 +140,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
@ -98,11 +156,18 @@ class GPTNeoXAttention(nn.Module):
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self._init_rope()
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.norm_factor = self.head_size**-0.5
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.is_causal = True
self.layer_idx = layer_idx
def _init_bias(self, max_positions, device=None):
self.register_buffer(
@ -146,9 +211,11 @@ class GPTNeoXAttention(nn.Module):
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
):
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
@ -199,9 +266,8 @@ class GPTNeoXAttention(nn.Module):
position_ids: torch.LongTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
@ -225,22 +291,31 @@ class GPTNeoXAttention(nn.Module):
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
if layer_past is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
seq_len += layer_past.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
if layer_past is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_emb.dim,
"cache_position": cache_position,
}
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
return query, key, value, present
return query, key, value, layer_past
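
A toy sketch (illustrative sizes, not part of the commit) of the partial-rotation split above: only the first `rotary_ndims` dimensions of each head go through RoPE, the rest pass through unchanged, and `partial_rotation_size` tells the cache which slice was rotated (relevant for caches that re-apply sin/cos, e.g. `SinkCache`).

import torch

head_size, rotary_pct = 8, 0.25
rotary_ndims = int(head_size * rotary_pct)   # 2 of the 8 head dims get rotary embeddings

query = torch.randn(1, 4, 5, head_size)      # (batch, heads, seq, head_size)
query_rot, query_pass = query[..., :rotary_ndims], query[..., rotary_ndims:]

# apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) would rotate only this slice
query = torch.cat((query_rot, query_pass), dim=-1)
print(query.shape, rotary_ndims)             # torch.Size([1, 4, 5, 8]) 2
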
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
@ -277,9 +352,9 @@ class GPTNeoXAttention(nn.Module):
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_scores = attn_scores + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_scores = attn_scores + causal_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
@ -316,13 +391,18 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
):
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
cache_position=cache_position,
)
query_length = query.shape[-2]
@ -384,7 +464,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
@ -398,8 +478,8 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
to adapt to the SDPA API.
"""
def __init__(self, config):
super().__init__(config)
def __init__(self, config, layer_idx=None):
super().__init__(config, layer_idx=layer_idx)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
@ -415,6 +495,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
):
if output_attentions or head_mask is not None:
logger.warning_once(
@ -431,15 +512,24 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
cache_position=cache_position,
)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
target_dtype = value.dtype
if query.dtype != target_dtype:
@ -455,13 +545,13 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 else False
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=is_causal,
)
@ -624,14 +714,14 @@ GPT_NEOX_ATTENTION_CLASSES = {
class GPTNeoXLayer(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.mlp = GPTNeoXMLP(config)
def forward(
@ -641,8 +731,9 @@ class GPTNeoXLayer(nn.Module):
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
@ -652,6 +743,7 @@ class GPTNeoXLayer(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
attn_output = self.post_attention_dropout(attn_output)
@ -722,6 +814,23 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and value states in the self-attention blocks) that can be used to
speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -730,6 +839,10 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -744,7 +857,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList([GPTNeoXLayer(config, i) for i in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._attn_implementation = config._attn_implementation
@ -774,18 +887,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@ -797,50 +906,42 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
# Attention mask.
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
seq_length = inputs_embeds.shape[1]
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
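The conversion step above is handled by `DynamicCache.from_legacy_cache`; a small round-trip sketch of what it does (shapes are illustrative):

import torch
from transformers import DynamicCache

legacy = ((torch.randn(1, 8, 2, 16), torch.randn(1, 8, 2, 16)),)  # one layer: (key, value)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 2 -> used above to build cache_position
print(len(cache.to_legacy_cache()))  # 1, converted back for the legacy return format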
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -848,20 +949,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = self.emb_dropout(inputs_embeds)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
next_decoder_cache = None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -869,26 +964,28 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
causal_mask,
position_ids,
head_mask[i],
use_cache,
None,
output_attentions,
cache_position,
)
else:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=layer_past,
layer_past=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
next_decoder_cache = outputs[1]
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
@ -897,16 +994,92 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
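As a sanity check on the mask built above, a small standalone run of the module-level helper copied from Llama (assumed importable in scope), with toy sizes: two new tokens landing at cache positions 2 and 3 of a length-4 static cache.

import torch

min_dtype = torch.finfo(torch.float32).min
mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 4),      # 2D padding mask, no padding
    sequence_length=2,                    # two new tokens this step
    target_length=4,                      # full static cache length
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=min_dtype,
    cache_position=torch.tensor([2, 3]),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 2, 4])
# The row for cache position 2 masks position 3; the row for position 3 attends to all four slots.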
@add_start_docstrings(
"""GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
@ -938,26 +1111,15 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
@ -997,6 +1159,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
@ -1024,24 +1187,27 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attentions=outputs.attentions,
)
# can't be copied from llama, gpt-neox has embed_out and not lm_head
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
@ -1049,24 +1215,46 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape[:2]  # inputs_embeds is (batch, seq, hidden)
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.embed_out.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
"use_cache": kwargs.get("use_cache"),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
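Put together, the `StaticCache` branch above is what `generate` exercises when a static cache is requested; a hedged end-to-end sketch, assuming the `_supports_static_cache` wiring this PR adds and an illustrative checkpoint name:

from transformers import AutoTokenizer, GPTNeoXForCausalLM

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")
inputs = tok("The static cache", return_tensors="pt")

# cache_implementation="static" makes generate allocate a StaticCache, so the 2D
# attention mask is expanded to 4D here in prepare_inputs_for_generation.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))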
def _reorder_cache(self, past_key_values, beam_idx):
@ -1117,7 +1305,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -1229,7 +1417,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,

View File

@ -24,6 +24,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -55,6 +57,60 @@ _REAL_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
_CONFIG_FOR_DOC = "GPTJConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the zero padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
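One detail worth noting in the helper: a caller-supplied 4D mask is assumed to be already inverted and is returned untouched. A quick check of that branch:

import torch

four_d = torch.zeros(1, 1, 2, 4)  # already-inverted 4D mask supplied by the caller
out = _prepare_4d_causal_attention_mask_with_cache_position(
    four_d,
    sequence_length=2,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.tensor([0, 1]),
    batch_size=1,
)
print(out is four_d)  # True: returned as-is, no slicing or inversion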
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
@ -80,23 +136,22 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
class GPTJAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.config = config
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.is_causal = True
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
@ -152,27 +207,16 @@ class GPTJAttention(nn.Module):
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
@ -196,12 +240,13 @@ class GPTJAttention(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@ -245,17 +290,13 @@ class GPTJAttention(nn.Module):
query = query.permute(0, 2, 1, 3)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
present = (key.to(hidden_states.dtype), value)
else:
present = None
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_dim,
"cache_position": cache_position,
}
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@ -264,7 +305,7 @@ class GPTJAttention(nn.Module):
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
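The attention layers now delegate concatenation to `Cache.update`, which appends the new key/value states for the given layer and returns everything accumulated so far; a toy illustration with a `DynamicCache` (shapes are illustrative):

import torch
from transformers import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 4, 3, 8)    # (batch, heads, new_tokens, head_dim)
value = torch.randn(1, 4, 3, 8)

k_all, v_all = cache.update(key, value, layer_idx=0)
print(k_all.shape)               # torch.Size([1, 4, 3, 8])

# One more decoding step: a single new token is appended to the same layer.
k_all, v_all = cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx=0)
print(k_all.shape)               # torch.Size([1, 4, 4, 8])
print(cache.get_seq_length())    # 4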
@ -290,12 +331,13 @@ class GPTJFlashAttention2(GPTJAttention):
def forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@ -343,17 +385,13 @@ class GPTJFlashAttention2(GPTJAttention):
# value: batch_size x num_attention_heads x seq_length x head_dim
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
present = (key.to(hidden_states.dtype), value)
else:
present = None
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_dim,
"cache_position": cache_position,
}
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
# The Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@ -412,7 +450,7 @@ class GPTJFlashAttention2(GPTJAttention):
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
@ -445,22 +483,23 @@ class GPTJMLP(nn.Module):
class GPTJBlock(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.mlp = GPTJMLP(inner_dim, config)
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -472,6 +511,7 @@ class GPTJBlock(nn.Module):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@ -500,6 +540,9 @@ class GPTJPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_param_buffer_assignment = False
def __init__(self, *inputs, **kwargs):
@ -572,6 +615,23 @@ GPTJ_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and value states in the self-attention blocks) that can be used to
speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@ -580,6 +640,10 @@ GPTJ_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
PARALLELIZE_DOCSTRING = r"""
@ -643,7 +707,7 @@ class GPTJModel(GPTJPreTrainedModel):
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])
self.h = nn.ModuleList([GPTJBlock(config, layer_idx=i) for i in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Model parallel
@ -714,7 +778,7 @@ class GPTJModel(GPTJPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -724,6 +788,7 @@ class GPTJModel(GPTJPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -732,73 +797,10 @@ class GPTJModel(GPTJPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
if not self._use_flash_attention_2:
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training:
if use_cache:
@ -807,19 +809,64 @@ class GPTJModel(GPTJPreTrainedModel):
)
use_cache = False
presents = () if use_cache else None
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
seq_length = inputs_embeds.shape[1]
if cache_position is None:
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1, seq_length, hidden_states.size(-1))
next_decoder_cache = None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
if past_key_values is not None:
# key_cache / value_cache are lists of per-layer tensors, so move each tensor individually
past_key_values.key_cache = [k.to(hidden_states.device) for k in past_key_values.key_cache]
past_key_values.value_cache = [v.to(hidden_states.device) for v in past_key_values.value_cache]
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if causal_mask is not None:
causal_mask = causal_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
@ -830,26 +877,28 @@ class GPTJModel(GPTJPreTrainedModel):
block.__call__,
hidden_states,
None,
attention_mask,
causal_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
cache_position,
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
next_decoder_cache = outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@ -867,16 +916,94 @@ class GPTJModel(GPTJPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@add_start_docstrings(
"""
@ -936,26 +1063,31 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
# Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
@ -963,22 +1095,47 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape[:2]  # inputs_embeds is (batch, seq, hidden)
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}
)
return model_inputs
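A toy run of the `cache_position` slicing rule used above: with five prompt tokens already in the cache, only the single unprocessed token id survives.

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])
cache_position = torch.tensor([5])   # one unprocessed position

# Same condition as in prepare_inputs_for_generation above.
if input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]
print(input_ids)  # tensor([[16]])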
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -991,7 +1148,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -1002,6 +1159,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1023,6 +1181,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = transformer_outputs[0]

View File

@ -30,7 +30,8 @@ from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@ -50,6 +51,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "IdeficsConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the zero padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@dataclass
class IdeficsBaseModelOutputWithPast(ModelOutput):
"""
@ -184,11 +239,13 @@ def expand_inputs_for_generation(
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
cache_position = kwargs.get("cache_position", None)
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@ -200,6 +257,9 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
pixel_values = kwargs.get("pixel_values", None)
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
@ -210,6 +270,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
@ -541,6 +602,7 @@ class IdeficsAttention(nn.Module):
is_cross_attention: bool = False,
config: PretrainedConfig = None,
qk_layer_norms: bool = False,
layer_idx: int = None,
):
super().__init__()
self.hidden_size = hidden_size
@ -549,6 +611,14 @@ class IdeficsAttention(nn.Module):
self.dropout = dropout
self.is_causal = True
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
@ -615,6 +685,7 @@ class IdeficsAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# if key_value_states are provided this layer is used as a cross-attention layer
is_cross_attention = self.is_cross_attention or key_value_states is not None
@ -634,18 +705,17 @@ class IdeficsAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
kv_seq_len += cache_position[0]
if not is_cross_attention:
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.qk_layer_norms:
query_states = self.q_layer_norm(query_states)
@ -700,7 +770,7 @@ class IdeficsAttention(nn.Module):
# this was adapted from LlamaDecoderLayer
class IdeficsDecoderLayer(nn.Module):
def __init__(self, config: IdeficsConfig):
def __init__(self, config: IdeficsConfig, layer_idx: int = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = IdeficsAttention(
@ -708,6 +778,7 @@ class IdeficsDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
dropout=config.dropout,
config=config,
layer_idx=layer_idx,
)
self.mlp = IdeficsMLP(
hidden_size=self.hidden_size,
@ -726,6 +797,7 @@ class IdeficsDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -753,6 +825,7 @@ class IdeficsDecoderLayer(nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -944,6 +1017,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@ -1031,6 +1105,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -1076,7 +1154,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
perceiver_config.resampler_n_latents,
)
self.layers = nn.ModuleList([IdeficsDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[IdeficsDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
)
self.cross_layer_interval = config.cross_layer_interval
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
@ -1132,6 +1212,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@ -1143,22 +1224,38 @@ class IdeficsModel(IdeficsPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
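The legacy tuple-of-tuples format keeps working through a round trip: it is wrapped in a `DynamicCache` on the way in and converted back with `to_legacy_cache()` before returning. A sketch of that round trip (shapes are illustrative):

import torch
from transformers.cache_utils import DynamicCache

legacy = tuple(
    (torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8))  # (key, value) per layer
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 5   # cached length, formerly past_key_values[0][0].shape[2]

round_tripped = cache.to_legacy_cache()
assert torch.equal(round_tripped[0][0], legacy[0][0])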
batch_size, seq_length, _ = inputs_embeds.shape
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
seq_length_with_past = seq_length + past_key_values_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@ -1229,37 +1326,27 @@ class IdeficsModel(IdeficsPreTrainedModel):
device
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
def vblock(
main_block,
hidden_states,
@ -1274,6 +1361,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
layer_idx,
cross_layer_interval,
gated_cross_attn_layers,
cache_position,
):
# TODO(ls): Add cross attention values to respective lists
if layer_idx % cross_layer_interval == 0:
@ -1297,12 +1385,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
return layer_outputs
if self.gradient_checkpointing and self.training:
past_key_value = None
past_key_values = None
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@ -1315,7 +1404,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
hidden_states,
attention_mask,
position_ids,
past_key_value,
past_key_values,
image_hidden_states,
image_attention_mask,
cross_attention_gate,
@ -1324,6 +1413,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
idx,
self.cross_layer_interval,
self.gated_cross_attn_layers,
cache_position,
)
else:
layer_outputs = vblock(
@ -1331,7 +1421,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
cross_attention_gate=cross_attention_gate,
@ -1340,12 +1430,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
layer_idx=idx,
cross_layer_interval=self.cross_layer_interval,
gated_cross_attn_layers=self.gated_cross_attn_layers,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -1357,6 +1448,8 @@ class IdeficsModel(IdeficsPreTrainedModel):
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
if not return_dict:
return tuple(
@ -1372,6 +1465,78 @@ class IdeficsModel(IdeficsPreTrainedModel):
image_hidden_states=image_hidden_states,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile, which then recaptures cudagraphs at each decode step due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output_attentions is True, the SDPA implementation's forward method falls back to the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
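A sketch of what the 4D mask produced above looks like for a two-token decode step against four already-cached tokens (illustrative values; the real helper also handles batching and padding):

import torch

sequence_length, target_length = 2, 6
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.tensor([4, 5])  # absolute positions of the two new tokens

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = causal_mask.masked_fill(
    torch.arange(target_length) <= cache_position.reshape(-1, 1), 0.0
)
# Row 0 may attend to positions 0..4, row 1 to 0..5; the mask is then expanded
# to (batch_size, 1, sequence_length, target_length) before reaching SDPA.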
class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
@ -1450,6 +1615,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
Args:
@ -1508,6 +1674,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
@ -1567,13 +1734,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
**kwargs,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder,
standardize_cache_format,
**kwargs,
)
if "image_attention_mask" in model_kwargs:

View File

@ -59,7 +59,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@ -1769,6 +1769,53 @@ class GenerationTesterMixin:
)
)
def test_generate_with_static_cache(self):
"""
Tests that StaticCache works when we set cache_implementation="static" in generation.
This doesn't test whether generation quality is good, but only that models with
self._supports_static_cache don't throw an error when generating, and that a
StaticCache object is returned at the end.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_length": None,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
max_cache_len = seq_length + max_new_tokens
head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
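The user-facing pattern this test exercises, outside the test harness (checkpoint name is only illustrative; any decoder-only model with `_supports_static_cache` works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "EleutherAI/gpt-neo-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

inputs = tokenizer("Static caches preallocate their tensors", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="static",   # asks generate() to build a StaticCache
    return_dict_in_generate=True,    # required to get past_key_values back
)
assert isinstance(out.past_key_values, StaticCache)
# Each layer's key cache is preallocated to (batch, num_kv_heads, max_cache_len, head_dim)
print(out.past_key_values.key_cache[0].shape)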
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:

View File

@ -4587,6 +4587,44 @@ class ModelTesterMixin:
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_static_cache_matches_dynamic(self):
"""
Tests that generating with a static cache gives almost the same results as with a dynamic cache.
This test does not compile the model; it only checks logits similarity, allowing for numerical
precision errors.
"""
if len(self.all_generative_model_classes) == 0:
self.skipTest(
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
)
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} does not support static cache")
if not model_class._supports_cache_class:
self.skipTest(f"{model_class.__name__} does not support cache class")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32)
model.eval()
dynamic_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
)
static_out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=10,
cache_implementation="static",
output_logits=True,
return_dict_in_generate=True,
)
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
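The same dynamic-vs-static check, written as a user-level snippet (checkpoint name is illustrative; with greedy decoding the generated tokens are expected to match as well):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()
inputs = tokenizer("Compare the two cache formats:", return_tensors="pt")

dynamic = model.generate(**inputs, do_sample=False, max_new_tokens=10,
                         output_logits=True, return_dict_in_generate=True)
static = model.generate(**inputs, do_sample=False, max_new_tokens=10,
                        cache_implementation="static",
                        output_logits=True, return_dict_in_generate=True)

# Logits agree up to small numerical noise; greedy sequences should be identical.
assert torch.allclose(dynamic.logits[0], static.logits[0], rtol=1e-3, atol=1e-4)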
# For now, let's focus only on GPU for `torch.compile`
@slow
@require_torch_gpu